From 420a328e8f49472ff2f9501f3aa8eecb96af935b Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 27 Nov 2023 11:20:47 -0300 Subject: [PATCH 01/17] Rename registries --- README.md | 4 +-- retrack/__init__.py | 4 +-- retrack/engine/parser.py | 41 +++++++++++++++++------------- retrack/nodes/__init__.py | 10 ++++++-- tests/test_nodes/test_csv_table.py | 8 +++--- 5 files changed, 40 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 24e5172..a5f820c 100644 --- a/README.md +++ b/README.md @@ -130,10 +130,10 @@ After creating the custom node, you need to register it in the nodes registry an import retrack # Register the custom node -retrack.component_registry.register_node("sum", SumNode) +retrack.nodes_registry.register_node("sum", SumNode) # Parse the rule/model -parser = Parser(rule, component_registry=retrack.component_registry) +parser = Parser(rule, nodes_registry=retrack.nodes_registry) ``` ## Contributing diff --git a/retrack/__init__.py b/retrack/__init__.py index fb11807..ede84ed 100644 --- a/retrack/__init__.py +++ b/retrack/__init__.py @@ -1,6 +1,6 @@ from retrack.engine.parser import Parser from retrack.engine.runner import Runner -from retrack.nodes import registry as component_registry +from retrack.nodes import registry as nodes_registry from retrack.nodes.base import BaseNode, InputConnectionModel, OutputConnectionModel __all__ = [ @@ -9,5 +9,5 @@ "BaseNode", "InputConnectionModel", "OutputConnectionModel", - "component_registry", + "nodes_registry", ] diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py index 0e47045..ad32433 100644 --- a/retrack/engine/parser.py +++ b/retrack/engine/parser.py @@ -13,22 +13,22 @@ class Parser: def __init__( self, graph_data: dict, - component_registry: Registry = nodes.registry(), - dynamic_registry: Registry = nodes.dynamic_registry(), + nodes_registry: Registry = nodes.registry(), + dynamic_nodes_registry: Registry = nodes.dynamic_nodes_registry(), validator_registry: Registry = validators.registry(), raise_if_null_version: bool = False, validate_version: bool = True, ): self.__graph_data = graph_data + self.__components_registry = Registry() self._execution_order = None - self.__components = {} self.__edges = None self._raise_if_null_version = raise_if_null_version self._validate_version = validate_version self._check_input_data(self.graph_data) - self._set_components(component_registry, dynamic_registry) + self._set_components(nodes_registry, dynamic_nodes_registry) self._set_edges() self._validate_graph(validator_registry) @@ -47,6 +47,14 @@ def graph_data(self) -> dict: def version(self) -> str: return self._version + @property + def components_registry(self) -> Registry: + return self.__components_registry + + @property + def components(self) -> typing.Dict[str, nodes.BaseNode]: + return self.components_registry.data + @staticmethod def _check_input_data(data: dict): if not isinstance(data, dict): @@ -57,24 +65,21 @@ def _check_input_data(data: dict): raise ValueError("No nodes found in data") if not isinstance(data["nodes"], dict): raise TypeError( - "BaseNodes must be a dictionary. Instead got: " - + str(type(data["nodes"])) + "Nodes must be a dictionary. Instead got: " + str(type(data["nodes"])) ) @staticmethod def _check_node_name(node_name: str, node_id: str): if node_name is None: - raise ValueError(f"BaseNode {node_id} has no name") + raise ValueError(f"Node {node_id} has no name") if not isinstance(node_name, str): - raise TypeError(f"BaseNode {node_id} name must be a string") + raise TypeError(f"Node {node_id} name must be a string") - @property - def components(self) -> typing.Dict[str, nodes.BaseNode]: - return self.__components - - def _set_components(self, component_registry: Registry, dynamic_registry: Registry): + def _set_components( + self, nodes_registry: Registry, dynamic_nodes_registry: Registry + ): for node_id, node_metadata in self.graph_data["nodes"].items(): - if node_id in self.__components: + if node_id in self.components_registry: raise ValueError(f"Duplicate node id: {node_id}") node_name = node_metadata.get("name", None) @@ -82,17 +87,19 @@ def _set_components(self, component_registry: Registry, dynamic_registry: Regist node_name = node_name.lower() - node_factory = dynamic_registry.get(node_name) + node_factory = dynamic_nodes_registry.get(node_name) if node_factory is not None: validation_model = node_factory(**node_metadata) else: - validation_model = component_registry.get(node_name) + validation_model = nodes_registry.get(node_name) if validation_model is None: raise ValueError(f"Unknown node name: {node_name}") - self.__components[node_id] = validation_model(**node_metadata) + self.components_registry.register( + node_id, validation_model(**node_metadata) + ) @property def edges(self) -> typing.List[typing.Tuple[str, str]]: diff --git a/retrack/nodes/__init__.py b/retrack/nodes/__init__.py index 4d01912..e4accab 100644 --- a/retrack/nodes/__init__.py +++ b/retrack/nodes/__init__.py @@ -4,7 +4,7 @@ from retrack.nodes.contains import Contains from retrack.nodes.datetime import CurrentYear from retrack.nodes.dynamic import BaseDynamicNode -from retrack.nodes.dynamic import registry as dynamic_registry +from retrack.nodes.dynamic import registry as dynamic_nodes_registry from retrack.nodes.endswith import EndsWith from retrack.nodes.endswithany import EndsWithAny from retrack.nodes.inputs import Input @@ -52,4 +52,10 @@ def register(name: str, node: BaseNode) -> None: register("IntervalCatV0", IntervalCatV0) register("LowerCase", LowerCase) -__all__ = ["registry", "register", "BaseNode", "dynamic_registry", "BaseDynamicNode"] +__all__ = [ + "registry", + "register", + "BaseNode", + "dynamic_nodes_registry", + "BaseDynamicNode", +] diff --git a/tests/test_nodes/test_csv_table.py b/tests/test_nodes/test_csv_table.py index a9fbaa1..0bc6a12 100644 --- a/tests/test_nodes/test_csv_table.py +++ b/tests/test_nodes/test_csv_table.py @@ -2,7 +2,7 @@ import pydantic import pytest -from retrack.nodes import dynamic_registry +from retrack.nodes import dynamic_nodes_registry @pytest.fixture @@ -60,13 +60,13 @@ def csv_table_metadata(): def test_get_csv_table_factory(): - csv_table_factory = dynamic_registry().get("CSVTableV0") + csv_table_factory = dynamic_nodes_registry().get("CSVTableV0") assert callable(csv_table_factory) def test_create_model_from_factory(csv_table_metadata): - csv_table_factory = dynamic_registry().get("CSVTableV0") + csv_table_factory = dynamic_nodes_registry().get("CSVTableV0") CSVTableV0 = csv_table_factory(**csv_table_metadata) assert issubclass(CSVTableV0, pydantic.BaseModel) @@ -78,7 +78,7 @@ def test_create_model_from_factory(csv_table_metadata): def test_csv_table_run(csv_table_metadata): - csv_table_factory = dynamic_registry().get("CSVTableV0") + csv_table_factory = dynamic_nodes_registry().get("CSVTableV0") CSVTableV0 = csv_table_factory(**csv_table_metadata) model = CSVTableV0(**csv_table_metadata) From 496e98ff594ca38fc9cca6bdaaffd84c5518a489 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 27 Nov 2023 12:04:39 -0300 Subject: [PATCH 02/17] Add input nodes generation from connectors --- retrack/engine/parser.py | 10 ++++++++++ retrack/nodes/__init__.py | 2 ++ retrack/nodes/inputs.py | 8 ++++++++ 3 files changed, 20 insertions(+) diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py index ad32433..f5155aa 100644 --- a/retrack/engine/parser.py +++ b/retrack/engine/parser.py @@ -4,6 +4,7 @@ from retrack import nodes, validators from retrack.utils.registry import Registry +from retrack.nodes.base import NodeKind import json from unidecode import unidecode @@ -38,6 +39,7 @@ def __init__( self._set_execution_order() self._set_indexes_by_memory_type_map() self._set_version() + self._set_input_nodes_from_connectors() @property def graph_data(self) -> dict: @@ -247,3 +249,11 @@ def _set_version(self): raise ValueError( f"Invalid version. Graph data has changed and the hash is different: {calculated_hash} != {file_version_hash}" ) + + def _set_input_nodes_from_connectors(self): + connector_nodes = self.get_by_kind(NodeKind.CONNECTOR) + + for connector_node in connector_nodes: + input_nodes = connector_node.generate_input_nodes() + for input_node in input_nodes: + self.components_registry.register(input_node.id, input_node) diff --git a/retrack/nodes/__init__.py b/retrack/nodes/__init__.py index e4accab..ead3882 100644 --- a/retrack/nodes/__init__.py +++ b/retrack/nodes/__init__.py @@ -30,6 +30,8 @@ def register(name: str, node: BaseNode) -> None: register("Input", Input) +register("Connector", Input) # virtual node +register("ConnectorV0", Input) # virtual node register("Start", Start) register("Constant", Constant) register("List", List) diff --git a/retrack/nodes/inputs.py b/retrack/nodes/inputs.py index 72db14f..3fb8aac 100644 --- a/retrack/nodes/inputs.py +++ b/retrack/nodes/inputs.py @@ -44,3 +44,11 @@ class Input(BaseNode): def kind(self) -> NodeKind: return NodeKind.INPUT + + +class BaseConnector(BaseNode): + def kind(self) -> NodeKind: + return NodeKind.CONNECTOR + + def generate_input_nodes(self) -> typing.List[Input]: + return [] From e1739e4e04273eff1a06fa45113fbedb09e9fe11 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Tue, 28 Nov 2023 10:51:09 -0300 Subject: [PATCH 03/17] Connector with virtual inputs --- retrack/engine/request_manager.py | 6 +- retrack/engine/runner.py | 9 +- retrack/nodes/__init__.py | 6 +- retrack/nodes/connectors.py | 20 ++ retrack/nodes/dynamic/flow.py | 4 +- retrack/nodes/inputs.py | 8 - tests/resources/connector-rule.json | 171 ++++++++++++++++++ .../rule-of-rules-with-connector.json | 122 +++++++++++++ tests/test_engine/test_runner.py | 30 +++ tests/test_nodes/test_connectors.py | 28 +++ 10 files changed, 387 insertions(+), 17 deletions(-) create mode 100644 retrack/nodes/connectors.py create mode 100644 tests/resources/connector-rule.json create mode 100644 tests/resources/rule-of-rules-with-connector.json create mode 100644 tests/test_nodes/test_connectors.py diff --git a/retrack/engine/request_manager.py b/retrack/engine/request_manager.py index 1007666..a3b155f 100644 --- a/retrack/engine/request_manager.py +++ b/retrack/engine/request_manager.py @@ -119,14 +119,14 @@ def __create_dataframe_model(self) -> pandera.DataFrameSchema: return pandera.DataFrameSchema( fields, index=pandera.Index(int), - strict=True, + # strict=True, coerce=True, ) def validate( self, payload: pd.DataFrame, - ) -> typing.List[pydantic.BaseModel]: + ) -> pd.DataFrame: """Validate the payload against the RequestManager's model Args: @@ -136,7 +136,7 @@ def validate( ValueError: If the RequestManager has no model Returns: - typing.List[pydantic.BaseModel]: The validated payload + pd.DataFrame: The validated payload """ if self.model is None: raise ValueError("No inputs found") diff --git a/retrack/engine/runner.py b/retrack/engine/runner.py index 45aeb74..b6b52f5 100644 --- a/retrack/engine/runner.py +++ b/retrack/engine/runner.py @@ -17,11 +17,12 @@ def __init__(self, parser: Parser, name: str = None): self._parser = parser self._name = name self._internal_runners = {} + self._validated_payload = None self.reset() self._set_constants() self._set_input_columns() - self._request_manager = RequestManager(self._parser.get_by_kind(NodeKind.INPUT)) self._set_internal_runners() + self._request_manager = RequestManager(self._parser.get_by_kind(NodeKind.INPUT)) @classmethod def from_json(cls, data: typing.Union[str, dict], name: str = None, **kwargs): @@ -118,13 +119,13 @@ def _create_initial_state_from_payload( self, payload_df: pd.DataFrame ) -> pd.DataFrame: """Create initial state from payload. This is the first step of the runner.""" - validated_payload = self.request_manager.validate( + self._validated_payload = self.request_manager.validate( payload_df.reset_index(drop=True) ) state_df = pd.DataFrame([]) for node_id, input_name in self.input_columns.items(): - state_df[node_id] = validated_payload[input_name] + state_df[node_id] = self._validated_payload[input_name] state_df[constants.OUTPUT_REFERENCE_COLUMN] = np.nan state_df[constants.OUTPUT_MESSAGE_REFERENCE_COLUMN] = np.nan @@ -147,6 +148,8 @@ def __get_input_params( if node_id in self._internal_runners: input_params["runner"] = self._internal_runners[node_id] + for column_name, column_value in self._validated_payload.items(): + input_params[f"payload_{column_name}"] = column_value return input_params diff --git a/retrack/nodes/__init__.py b/retrack/nodes/__init__.py index ead3882..a521ef0 100644 --- a/retrack/nodes/__init__.py +++ b/retrack/nodes/__init__.py @@ -15,6 +15,7 @@ from retrack.nodes.start import Start from retrack.nodes.startswith import StartsWith from retrack.nodes.startswithany import StartsWithAny +from retrack.nodes.connectors import BaseConnector, VirtualConnector from retrack.nodes.lowercase import LowerCase from retrack.utils.registry import Registry @@ -30,8 +31,8 @@ def register(name: str, node: BaseNode) -> None: register("Input", Input) -register("Connector", Input) # virtual node -register("ConnectorV0", Input) # virtual node +register("Connector", VirtualConnector) # By default, Connector is an Input +register("ConnectorV0", VirtualConnector) # By default, Connector is an Input register("Start", Start) register("Constant", Constant) register("List", List) @@ -60,4 +61,5 @@ def register(name: str, node: BaseNode) -> None: "BaseNode", "dynamic_nodes_registry", "BaseDynamicNode", + "BaseConnector", ] diff --git a/retrack/nodes/connectors.py b/retrack/nodes/connectors.py new file mode 100644 index 0000000..d17e013 --- /dev/null +++ b/retrack/nodes/connectors.py @@ -0,0 +1,20 @@ +import typing + +from retrack.nodes.base import NodeKind +from retrack.nodes.inputs import Input + + +class VirtualConnector(Input): + def kind(self) -> NodeKind: + return NodeKind.INPUT + + def generate_input_nodes(self) -> typing.List[Input]: + return [] + + +class BaseConnector(VirtualConnector): + def kind(self) -> NodeKind: + return NodeKind.CONNECTOR + + def generate_input_nodes(self) -> typing.List[Input]: + raise NotImplementedError() diff --git a/retrack/nodes/dynamic/flow.py b/retrack/nodes/dynamic/flow.py index 36e70c6..de57c5b 100644 --- a/retrack/nodes/dynamic/flow.py +++ b/retrack/nodes/dynamic/flow.py @@ -50,7 +50,9 @@ def run(self, **kwargs) -> typing.Dict[str, typing.Any]: for name, value in kwargs.items(): if name.startswith("input_"): - inputs_in_kwargs[name[6:]] = value + inputs_in_kwargs[name[len("input_") :]] = value + elif name.startswith("payload_"): + inputs_in_kwargs[name[len("payload_") :]] = value response = runner.execute(pd.DataFrame(inputs_in_kwargs)) diff --git a/retrack/nodes/inputs.py b/retrack/nodes/inputs.py index 3fb8aac..72db14f 100644 --- a/retrack/nodes/inputs.py +++ b/retrack/nodes/inputs.py @@ -44,11 +44,3 @@ class Input(BaseNode): def kind(self) -> NodeKind: return NodeKind.INPUT - - -class BaseConnector(BaseNode): - def kind(self) -> NodeKind: - return NodeKind.CONNECTOR - - def generate_input_nodes(self) -> typing.List[Input]: - return [] diff --git a/tests/resources/connector-rule.json b/tests/resources/connector-rule.json new file mode 100644 index 0000000..f1ebf04 --- /dev/null +++ b/tests/resources/connector-rule.json @@ -0,0 +1,171 @@ +{ + "id": "demo@0.1.0", + "nodes": { + "0": { + "id": 0, + "data": {}, + "inputs": {}, + "outputs": { + "output_up_void": { + "connections": [ + { + "node": 5, + "input": "input_void", + "data": {} + } + ] + }, + "output_down_void": { + "connections": [ + { + "node": 10, + "input": "input_void", + "data": {} + } + ] + } + }, + "position": [ + 36.999198273089995, + 200.90571277556646 + ], + "name": "Start" + }, + "4": { + "id": 4, + "data": { + "message": null + }, + "inputs": { + "input_value": { + "connections": [ + { + "node": 9, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": {}, + "position": [ + 945.7727887309584, + 74.29575409540634 + ], + "name": "Output" + }, + "5": { + "id": 5, + "data": { + "name": "prediction", + "service": "service-name", + "identifier": "ml-model-name", + "default": null + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 0, + "output": "output_up_void", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 9, + "input": "input_value_0", + "data": {} + } + ] + } + }, + "position": [ + 395.07074381698476, + -177.8616133049862 + ], + "name": "ConnectorV0" + }, + "9": { + "id": 9, + "data": { + "operator": "*" + }, + "inputs": { + "input_value_0": { + "connections": [ + { + "node": 5, + "output": "output_value", + "data": {} + } + ] + }, + "input_value_1": { + "connections": [ + { + "node": 10, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 4, + "input": "input_value", + "data": {} + } + ] + } + }, + "position": [ + 693.0432160700918, + 41.0955985999175 + ], + "name": "Math" + }, + "10": { + "id": 10, + "data": { + "name": "multiplier", + "default": null + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 0, + "output": "output_down_void", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 9, + "input": "input_value_1", + "data": {} + } + ] + } + }, + "position": [ + 401.3893004299958, + 231.106029062272 + ], + "name": "Input" + } + }, + "version": "5e44160271.2023-11-27" +} \ No newline at end of file diff --git a/tests/resources/rule-of-rules-with-connector.json b/tests/resources/rule-of-rules-with-connector.json new file mode 100644 index 0000000..c22e040 --- /dev/null +++ b/tests/resources/rule-of-rules-with-connector.json @@ -0,0 +1,122 @@ +{ + "id": "demo@0.1.0", + "nodes": { + "0": { + "id": 0, + "data": {}, + "inputs": {}, + "outputs": { + "output_up_void": { + "connections": [ + { + "node": 4, + "input": "input_void", + "data": {} + } + ] + }, + "output_down_void": { + "connections": [] + } + }, + "position": [ + -298.1328125, + 175.67578125 + ], + "name": "Start" + }, + "2": { + "id": 2, + "data": { + "value": "{\"id\":\"demo@0.1.0\",\"nodes\":{\"0\":{\"id\":0,\"data\":{},\"inputs\":{},\"outputs\":{\"output_up_void\":{\"connections\":[{\"node\":5,\"input\":\"input_void\",\"data\":{}}]},\"output_down_void\":{\"connections\":[{\"node\":10,\"input\":\"input_void\",\"data\":{}}]}},\"position\":[36.999198273089995,200.90571277556646],\"name\":\"Start\"},\"4\":{\"id\":4,\"data\":{\"message\":null},\"inputs\":{\"input_value\":{\"connections\":[{\"node\":9,\"output\":\"output_value\",\"data\":{}}]}},\"outputs\":{},\"position\":[945.7727887309584,74.29575409540634],\"name\":\"Output\"},\"5\":{\"id\":5,\"data\":{\"name\":\"prediction\",\"service\":\"service-name\",\"identifier\":\"ml-model-name\",\"default\":null},\"inputs\":{\"input_void\":{\"connections\":[{\"node\":0,\"output\":\"output_up_void\",\"data\":{}}]}},\"outputs\":{\"output_value\":{\"connections\":[{\"node\":9,\"input\":\"input_value_0\",\"data\":{}}]}},\"position\":[395.07074381698476,-177.8616133049862],\"name\":\"ConnectorV0\"},\"9\":{\"id\":9,\"data\":{\"operator\":\"*\"},\"inputs\":{\"input_value_0\":{\"connections\":[{\"node\":5,\"output\":\"output_value\",\"data\":{}}]},\"input_value_1\":{\"connections\":[{\"node\":10,\"output\":\"output_value\",\"data\":{}}]}},\"outputs\":{\"output_value\":{\"connections\":[{\"node\":4,\"input\":\"input_value\",\"data\":{}}]}},\"position\":[693.0432160700918,41.0955985999175],\"name\":\"Math\"},\"10\":{\"id\":10,\"data\":{\"name\":\"multiplier\",\"default\":null},\"inputs\":{\"input_void\":{\"connections\":[{\"node\":0,\"output\":\"output_down_void\",\"data\":{}}]}},\"outputs\":{\"output_value\":{\"connections\":[{\"node\":9,\"input\":\"input_value_1\",\"data\":{}}]}},\"position\":[401.3893004299958,231.106029062272],\"name\":\"Input\"}},\"version\":\"5e44160271.2023-11-27\"}", + "name": "", + "default": null + }, + "inputs": { + "input_multiplier": { + "connections": [ + { + "node": 4, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 3, + "input": "input_value", + "data": {} + } + ] + } + }, + "position": [ + 239.1640625, + -156.0390625 + ], + "name": "FlowV0" + }, + "3": { + "id": 3, + "data": { + "message": null + }, + "inputs": { + "input_value": { + "connections": [ + { + "node": 2, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": {}, + "position": [ + 517.9873916270117, + -27.149735568960978 + ], + "name": "Output" + }, + "4": { + "id": 4, + "data": { + "name": "var", + "default": null + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 0, + "output": "output_up_void", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 2, + "input": "input_multiplier", + "data": {} + } + ] + } + }, + "position": [ + -16.37890625, + 6.32421875 + ], + "name": "Input" + } + }, + "version": "ffe9a0173a.2023-11-27" +} \ No newline at end of file diff --git a/tests/test_engine/test_runner.py b/tests/test_engine/test_runner.py index 4701d41..f9f7842 100644 --- a/tests/test_engine/test_runner.py +++ b/tests/test_engine/test_runner.py @@ -161,6 +161,36 @@ def test_flows(filename, in_values, expected_out_values): {"output": 25, "message": None}, ], ), + ( + "connector-rule", + [ + {"prediction": "1", "multiplier": "1"}, + {"prediction": "2", "multiplier": "1"}, + {"prediction": "3", "multiplier": "1"}, + {"prediction": "4", "multiplier": "1"}, + ], + [ + {"output": 1.0, "message": None}, + {"output": 2.0, "message": None}, + {"output": 3.0, "message": None}, + {"output": 4.0, "message": None}, + ], + ), + ( + "rule-of-rules-with-connector", + [ + {"prediction": "1", "var": "1"}, + {"prediction": "2", "var": "1"}, + {"prediction": "3", "var": "1"}, + {"prediction": "4", "var": "1"}, + ], + [ + {"output": 1.0, "message": None}, + {"output": 2.0, "message": None}, + {"output": 3.0, "message": None}, + {"output": 4.0, "message": None}, + ], + ), ], ) def test_create_from_json(filename, in_values, expected_out_values): diff --git a/tests/test_nodes/test_connectors.py b/tests/test_nodes/test_connectors.py new file mode 100644 index 0000000..7215265 --- /dev/null +++ b/tests/test_nodes/test_connectors.py @@ -0,0 +1,28 @@ +from retrack.nodes.connectors import BaseConnector +import pytest + + +@pytest.fixture +def connector_dict(): + return { + "id": 1, + "data": {"name": "example", "default": "Hello World"}, + "inputs": {}, + "outputs": { + "output_value": { + "connections": [{"node": 0, "input": "input_void", "data": {}}] + } + }, + "position": [-444.37109375, 175.50390625], + "name": "Connector", + } + + +def test_create_base_connector(connector_dict): + connector = BaseConnector(**connector_dict) + + assert isinstance(connector, BaseConnector) + assert connector.kind() == "connector" + + with pytest.raises(NotImplementedError): + _ = connector.generate_input_nodes() From c111a87da18c52778419863e1da6c2c3c2b4d54f Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 4 Dec 2023 10:55:16 -0300 Subject: [PATCH 04/17] Refactor component registry --- retrack/engine/parser.py | 91 ++------------------- retrack/engine/runner.py | 20 +++-- retrack/utils/component_registry.py | 122 ++++++++++++++++++++++++++++ tests/test_parser.py | 2 +- 4 files changed, 145 insertions(+), 90 deletions(-) create mode 100644 retrack/utils/component_registry.py diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py index f5155aa..fe3ff17 100644 --- a/retrack/engine/parser.py +++ b/retrack/engine/parser.py @@ -4,6 +4,7 @@ from retrack import nodes, validators from retrack.utils.registry import Registry +from retrack.utils.component_registry import ComponentRegistry from retrack.nodes.base import NodeKind import json @@ -21,7 +22,7 @@ def __init__( validate_version: bool = True, ): self.__graph_data = graph_data - self.__components_registry = Registry() + self.__components_registry = ComponentRegistry() self._execution_order = None self.__edges = None self._raise_if_null_version = raise_if_null_version @@ -34,10 +35,7 @@ def __init__( self._validate_graph(validator_registry) - self._set_indexes_by_name_map() - self._set_indexes_by_kind_map() self._set_execution_order() - self._set_indexes_by_memory_type_map() self._set_version() self._set_input_nodes_from_connectors() @@ -50,7 +48,7 @@ def version(self) -> str: return self._version @property - def components_registry(self) -> Registry: + def components_registry(self) -> ComponentRegistry: return self.__components_registry @property @@ -123,96 +121,25 @@ def _validate_graph(self, validator_registry: Registry): def get_by_id(self, id_: str) -> nodes.BaseNode: return self.components.get(id_) - @property - def indexes_by_name_map(self) -> typing.Dict[str, typing.List[str]]: - return self._indexes_by_name_map - - def _set_indexes_by_name_map(self): - self._indexes_by_name_map = {} - - for node_id, node in self.components.items(): - node_name = node.__class__.__name__.lower() - if node_name not in self._indexes_by_name_map: - self._indexes_by_name_map[node_name] = [] - - self._indexes_by_name_map[node_name].append(node_id) - - def get_by_name(self, name: str) -> typing.List[nodes.BaseNode]: - name = name.lower() - return [self.get_by_id(id_) for id_ in self.indexes_by_name_map.get(name, [])] - - @property - def indexes_by_kind_map(self) -> typing.Dict[str, typing.List[str]]: - return self._indexes_by_kind_map - - def _set_indexes_by_kind_map(self): - self._indexes_by_kind_map = {} - - for node_id, node in self.components.items(): - if node.kind() not in self._indexes_by_kind_map: - self._indexes_by_kind_map[node.kind()] = [] - - self._indexes_by_kind_map[node.kind()].append(node_id) - - def get_by_kind(self, kind: str) -> typing.List[nodes.BaseNode]: - return [self.get_by_id(id_) for id_ in self.indexes_by_kind_map.get(kind, [])] - - @property - def indexes_by_memory_type_map(self) -> typing.Dict[str, typing.List[str]]: - return self._indexes_by_memory_type_map - - def _set_indexes_by_memory_type_map(self): - self._indexes_by_memory_type_map = {} - - for node_id, node in self.components.items(): - memory_type = node.memory_type() - if memory_type not in self.indexes_by_memory_type_map: - self._indexes_by_memory_type_map[memory_type] = [] - - self._indexes_by_memory_type_map[memory_type].append(node_id) - - def get_by_memory_type(self, memory_type: str) -> typing.List[nodes.BaseNode]: - return [ - self.get_by_id(id_) - for id_ in self.indexes_by_memory_type_map.get(memory_type, []) - ] - @property def execution_order(self) -> typing.List[str]: return self._execution_order def _set_execution_order(self): - start_nodes = self.get_by_name("start") + start_nodes = self.components_registry.get_by_name("start") self._execution_order = self._walk(start_nodes[0].id, []) - def get_node_connections( - self, node_id: str, is_input: bool = True, filter_by_connector=None - ): - node_dict = self.get_by_id(node_id).model_dump(by_alias=True) - - connectors = node_dict.get("inputs" if is_input else "outputs", {}) - result = [] - - for connector_name, value in connectors.items(): - if ( - filter_by_connector is not None - and connector_name != filter_by_connector - ): - continue - - for connection in value["connections"]: - result.append(connection["node"]) - return result - def _walk(self, actual_id: str, skiped_ids: list): skiped_ids.append(actual_id) - output_ids = self.get_node_connections(actual_id, is_input=False) + output_ids = self.components_registry.get_node_output_connections(actual_id) for next_id in output_ids: if next_id not in skiped_ids: - next_node_input_ids = self.get_node_connections(next_id, is_input=True) + next_node_input_ids = ( + self.components_registry.get_node_input_connections(next_id) + ) run_next = True for next_node_input_id in next_node_input_ids: if next_node_input_id not in skiped_ids: @@ -251,7 +178,7 @@ def _set_version(self): ) def _set_input_nodes_from_connectors(self): - connector_nodes = self.get_by_kind(NodeKind.CONNECTOR) + connector_nodes = self.components_registry.get_by_kind(NodeKind.CONNECTOR) for connector_node in connector_nodes: input_nodes = connector_node.generate_input_nodes() diff --git a/retrack/engine/runner.py b/retrack/engine/runner.py index b6b52f5..beab6f0 100644 --- a/retrack/engine/runner.py +++ b/retrack/engine/runner.py @@ -22,7 +22,9 @@ def __init__(self, parser: Parser, name: str = None): self._set_constants() self._set_input_columns() self._set_internal_runners() - self._request_manager = RequestManager(self._parser.get_by_kind(NodeKind.INPUT)) + self._request_manager = RequestManager( + self._parser.components_registry.get_by_kind(NodeKind.INPUT) + ) @classmethod def from_json(cls, data: typing.Union[str, dict], name: str = None, **kwargs): @@ -65,14 +67,16 @@ def constants(self) -> dict: return self._constants def _set_constants(self): - constant_nodes = self.parser.get_by_memory_type(NodeMemoryType.CONSTANT) + constant_nodes = self.parser.components_registry.get_by_memory_type( + NodeMemoryType.CONSTANT + ) self._constants = {} for node in constant_nodes: for output_connector_name, _ in node.outputs: self._constants[f"{node.id}@{output_connector_name}"] = node.data.value def _set_internal_runners(self): - for node_id in self.parser.indexes_by_name_map.get( + for node_id in self.parser.components_registry.indexes_by_name_map.get( constants.FLOW_NODE_NAME, [] ): try: @@ -90,7 +94,7 @@ def input_columns(self) -> dict: return self._input_columns def _set_input_columns(self): - input_nodes = self._parser.get_by_kind(NodeKind.INPUT) + input_nodes = self._parser.components_registry.get_by_kind(NodeKind.INPUT) self._input_columns = { f"{node.id}@{constants.INPUT_OUTPUT_VALUE_CONNECTOR_NAME}": node.data.name for node in input_nodes @@ -101,11 +105,13 @@ def reset(self): self._filters = {} def __set_output_connection_filters( - self, node_id: str, filter: typing.Any, filter_by_connector=None + self, node_id: str, filter: typing.Any, connector_filter=None ): if filter is not None: - output_connections = self.parser.get_node_connections( - node_id, is_input=False, filter_by_connector=filter_by_connector + output_connections = ( + self.parser.components_registry.get_node_output_connections( + node_id, connector_filter=connector_filter + ) ) for output_connection_id in output_connections: if self._filters.get(output_connection_id, None) is None: diff --git a/retrack/utils/component_registry.py b/retrack/utils/component_registry.py new file mode 100644 index 0000000..31e2532 --- /dev/null +++ b/retrack/utils/component_registry.py @@ -0,0 +1,122 @@ +import typing +from retrack.nodes.base import BaseNode +from retrack.utils.registry import Registry + + +class ComponentRegistry(Registry): + """A registry to store instances of BaseNode (aka Components). + + It also provides indexes to access the data by name, kind and memory type.""" + + def __init__(self, case_sensitive: bool = False): + super().__init__(case_sensitive=case_sensitive) + self._indexes_by_name_map = {} + self._indexes_by_kind_map = {} + self._indexes_by_memory_type_map = {} + + @property + def indexes_by_name_map(self) -> typing.Dict[str, typing.List[str]]: + return self._indexes_by_name_map + + @property + def indexes_by_kind_map(self) -> typing.Dict[str, typing.List[str]]: + return self._indexes_by_kind_map + + @property + def indexes_by_memory_type_map(self) -> typing.Dict[str, typing.List[str]]: + return self._indexes_by_memory_type_map + + def __register_in_indexes_by_name_map(self, name: str, data: BaseNode): + node_name = data.__class__.__name__.lower() + if node_name not in self._indexes_by_name_map: + self._indexes_by_name_map[node_name] = [] + + self._indexes_by_name_map[node_name].append(name) + + def __register_in_indexes_by_kind_map(self, name: str, data: BaseNode): + node_kind = data.kind() + if node_kind not in self._indexes_by_kind_map: + self._indexes_by_kind_map[node_kind] = [] + + self._indexes_by_kind_map[node_kind].append(name) + + def __register_in_indexes_by_memory_type_map(self, name: str, data: BaseNode): + memory_type = data.memory_type() + + if memory_type not in self._indexes_by_memory_type_map: + self._indexes_by_memory_type_map[memory_type] = [] + + self._indexes_by_memory_type_map[memory_type].append(name) + + def __unregister_from_indexes_by_name_map(self, name: str, data: BaseNode): + node_name = data.__class__.__name__.lower() + self._indexes_by_name_map[node_name].remove(name) + + def __unregister_from_indexes_by_kind_map(self, name: str, data: BaseNode): + node_kind = data.kind() + self._indexes_by_kind_map[node_kind].remove(name) + + def __unregister_from_indexes_by_memory_type_map(self, name: str, data: BaseNode): + memory_type = data.memory_type() + self._indexes_by_memory_type_map[memory_type].remove(name) + + def register(self, name: str, data: BaseNode, overwrite: bool = False): + """Register an entry.""" + if not isinstance(data, BaseNode): + raise ValueError("data must be a BaseNode instance.") + + super().register(name, data, overwrite=overwrite) + + self.__register_in_indexes_by_name_map(name, data) + self.__register_in_indexes_by_kind_map(name, data) + self.__register_in_indexes_by_memory_type_map(name, data) + + def unregister(self, name: str): + """Unregister an entry.""" + if not self._case_sensitive: + name = name.lower() + + data = self._data.pop(name, None) + + if data is None: + return + + self.__unregister_from_indexes_by_name_map(name, data) + self.__unregister_from_indexes_by_kind_map(name, data) + self.__unregister_from_indexes_by_memory_type_map(name, data) + + def get_by_name(self, name: str) -> typing.List[BaseNode]: + name = name.lower() + return [self.get(id_) for id_ in self.indexes_by_name_map.get(name, [])] + + def get_by_kind(self, kind: str) -> typing.List[BaseNode]: + return [self.get(id_) for id_ in self.indexes_by_kind_map.get(kind, [])] + + def get_by_memory_type(self, memory_type: str) -> typing.List[BaseNode]: + return [ + self.get(id_) + for id_ in self.indexes_by_memory_type_map.get(memory_type, []) + ] + + def _filter_connectors(self, connectors, connector_filter): + result = [] + + for connector_name, value in connectors.items(): + if connector_filter is not None and connector_name != connector_filter: + continue + + for connection in value["connections"]: + result.append(connection["node"]) + return result + + def get_node_input_connections(self, node_id: str, connector_filter=None): + node_dict = self.get(node_id).model_dump(by_alias=True) + + connectors = node_dict.get("inputs", {}) + return self._filter_connectors(connectors, connector_filter) + + def get_node_output_connections(self, node_id: str, connector_filter=None): + node_dict = self.get(node_id).model_dump(by_alias=True) + + connectors = node_dict.get("outputs", {}) + return self._filter_connectors(connectors, connector_filter) diff --git a/tests/test_parser.py b/tests/test_parser.py index a94a9eb..15976e4 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -27,7 +27,7 @@ def test_parser_extract(data_filename, expected_tokens): input_data = json.load(f) parser = Parser(input_data) - assert parser.indexes_by_name_map == expected_tokens + assert parser.components_registry.indexes_by_name_map == expected_tokens def test_parser_with_unknown_node(): From 692864cfe6a7d7305e399a415b4075b635a4f2bc Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 4 Dec 2023 11:00:39 -0300 Subject: [PATCH 05/17] Refactor component retrieval in parser and runner --- retrack/engine/parser.py | 9 +-------- retrack/engine/runner.py | 4 ++-- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py index fe3ff17..a2374bd 100644 --- a/retrack/engine/parser.py +++ b/retrack/engine/parser.py @@ -51,10 +51,6 @@ def version(self) -> str: def components_registry(self) -> ComponentRegistry: return self.__components_registry - @property - def components(self) -> typing.Dict[str, nodes.BaseNode]: - return self.components_registry.data - @staticmethod def _check_input_data(data: dict): if not isinstance(data, dict): @@ -108,7 +104,7 @@ def edges(self) -> typing.List[typing.Tuple[str, str]]: def _set_edges(self): self.__edges = [] - for node_id, node in self.components.items(): + for node_id, node in self.components_registry.data.items(): for _, output_connection in node.outputs: for c in output_connection.connections: self.__edges.append((node_id, c.node)) @@ -118,9 +114,6 @@ def _validate_graph(self, validator_registry: Registry): if not validator.validate(graph_data=self.graph_data, edges=self.edges): raise ValueError(f"Invalid graph data: {validator_name}") - def get_by_id(self, id_: str) -> nodes.BaseNode: - return self.components.get(id_) - @property def execution_order(self) -> typing.List[str]: return self._execution_order diff --git a/retrack/engine/runner.py b/retrack/engine/runner.py index beab6f0..7b545bc 100644 --- a/retrack/engine/runner.py +++ b/retrack/engine/runner.py @@ -80,7 +80,7 @@ def _set_internal_runners(self): constants.FLOW_NODE_NAME, [] ): try: - node_data = self.parser.get_by_id(node_id).data + node_data = self.parser.components_registry.get(node_id).data self._internal_runners[node_id] = Runner.from_json( node_data.parsed_value(), name=node_data.name ) @@ -181,7 +181,7 @@ def __run_node(self, node_id: str): # if there is a filter, we need to set the children nodes to receive filtered data self.__set_output_connection_filters(node_id, current_node_filter) - node = self.parser.get_by_id(node_id) + node = self.parser.components_registry.get(node_id) if node.memory_type == NodeMemoryType.CONSTANT: return From 3888c2498a6d79c880fbbd8c3be659f5637fa527 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 4 Dec 2023 12:04:35 -0300 Subject: [PATCH 06/17] Removing responsabilities from Parser --- retrack/engine/parser.py | 144 +++------------------------- retrack/utils/component_registry.py | 40 +++++--- retrack/utils/graph.py | 121 +++++++++++++++++++++++ 3 files changed, 163 insertions(+), 142 deletions(-) create mode 100644 retrack/utils/graph.py diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py index a2374bd..a7a2edb 100644 --- a/retrack/engine/parser.py +++ b/retrack/engine/parser.py @@ -1,14 +1,11 @@ import typing -import hashlib from retrack import nodes, validators from retrack.utils.registry import Registry from retrack.utils.component_registry import ComponentRegistry from retrack.nodes.base import NodeKind -import json - -from unidecode import unidecode +from retrack.utils import graph class Parser: @@ -21,22 +18,22 @@ def __init__( raise_if_null_version: bool = False, validate_version: bool = True, ): + self.__components_registry = graph.create_component_registry( + graph_data, nodes_registry, dynamic_nodes_registry + ) + self._version = graph.validate_version( + graph_data, raise_if_null_version, validate_version + ) self.__graph_data = graph_data - self.__components_registry = ComponentRegistry() - self._execution_order = None - self.__edges = None - self._raise_if_null_version = raise_if_null_version - self._validate_version = validate_version - self._check_input_data(self.graph_data) - - self._set_components(nodes_registry, dynamic_nodes_registry) - self._set_edges() + graph.validate_with_validators( + self.graph_data, + self.components_registry.calculate_edges(), + validator_registry, + ) - self._validate_graph(validator_registry) + self._execution_order = graph.get_execution_order(self.components_registry) - self._set_execution_order() - self._set_version() self._set_input_nodes_from_connectors() @property @@ -51,125 +48,10 @@ def version(self) -> str: def components_registry(self) -> ComponentRegistry: return self.__components_registry - @staticmethod - def _check_input_data(data: dict): - if not isinstance(data, dict): - raise TypeError( - "Data must be a dictionary. Instead got: " + str(type(data)) - ) - if "nodes" not in data: - raise ValueError("No nodes found in data") - if not isinstance(data["nodes"], dict): - raise TypeError( - "Nodes must be a dictionary. Instead got: " + str(type(data["nodes"])) - ) - - @staticmethod - def _check_node_name(node_name: str, node_id: str): - if node_name is None: - raise ValueError(f"Node {node_id} has no name") - if not isinstance(node_name, str): - raise TypeError(f"Node {node_id} name must be a string") - - def _set_components( - self, nodes_registry: Registry, dynamic_nodes_registry: Registry - ): - for node_id, node_metadata in self.graph_data["nodes"].items(): - if node_id in self.components_registry: - raise ValueError(f"Duplicate node id: {node_id}") - - node_name = node_metadata.get("name", None) - self._check_node_name(node_name, node_id) - - node_name = node_name.lower() - - node_factory = dynamic_nodes_registry.get(node_name) - - if node_factory is not None: - validation_model = node_factory(**node_metadata) - else: - validation_model = nodes_registry.get(node_name) - - if validation_model is None: - raise ValueError(f"Unknown node name: {node_name}") - - self.components_registry.register( - node_id, validation_model(**node_metadata) - ) - - @property - def edges(self) -> typing.List[typing.Tuple[str, str]]: - return self.__edges - - def _set_edges(self): - self.__edges = [] - - for node_id, node in self.components_registry.data.items(): - for _, output_connection in node.outputs: - for c in output_connection.connections: - self.__edges.append((node_id, c.node)) - - def _validate_graph(self, validator_registry: Registry): - for validator_name, validator in validator_registry.data.items(): - if not validator.validate(graph_data=self.graph_data, edges=self.edges): - raise ValueError(f"Invalid graph data: {validator_name}") - @property def execution_order(self) -> typing.List[str]: return self._execution_order - def _set_execution_order(self): - start_nodes = self.components_registry.get_by_name("start") - - self._execution_order = self._walk(start_nodes[0].id, []) - - def _walk(self, actual_id: str, skiped_ids: list): - skiped_ids.append(actual_id) - - output_ids = self.components_registry.get_node_output_connections(actual_id) - - for next_id in output_ids: - if next_id not in skiped_ids: - next_node_input_ids = ( - self.components_registry.get_node_input_connections(next_id) - ) - run_next = True - for next_node_input_id in next_node_input_ids: - if next_node_input_id not in skiped_ids: - run_next = False - break - - if run_next: - self._walk(next_id, skiped_ids) - - return skiped_ids - - def _set_version(self): - self._version = self.graph_data.get("version", None) - - graph_json_content = ( - json.dumps(self.graph_data["nodes"], ensure_ascii=False) - .replace(": ", ":") - .replace("\\", "") - .replace('"', "") - .replace(", ", ",") - ) - graph_json_content = unidecode(graph_json_content, errors="strict") - calculated_hash = hashlib.sha256(graph_json_content.encode()).hexdigest()[:10] - - if self.version is None: - if self._raise_if_null_version: - raise ValueError("Missing version") - - self._version = f"{calculated_hash}.dynamic" - else: - file_version_hash = self.version.split(".")[0] - - if file_version_hash != calculated_hash and self._validate_version: - raise ValueError( - f"Invalid version. Graph data has changed and the hash is different: {calculated_hash} != {file_version_hash}" - ) - def _set_input_nodes_from_connectors(self): connector_nodes = self.components_registry.get_by_kind(NodeKind.CONNECTOR) diff --git a/retrack/utils/component_registry.py b/retrack/utils/component_registry.py index 31e2532..0657a94 100644 --- a/retrack/utils/component_registry.py +++ b/retrack/utils/component_registry.py @@ -26,21 +26,23 @@ def indexes_by_kind_map(self) -> typing.Dict[str, typing.List[str]]: def indexes_by_memory_type_map(self) -> typing.Dict[str, typing.List[str]]: return self._indexes_by_memory_type_map - def __register_in_indexes_by_name_map(self, name: str, data: BaseNode): + def __register_in_indexes_by_name_map(self, name: str, data: BaseNode) -> None: node_name = data.__class__.__name__.lower() if node_name not in self._indexes_by_name_map: self._indexes_by_name_map[node_name] = [] self._indexes_by_name_map[node_name].append(name) - def __register_in_indexes_by_kind_map(self, name: str, data: BaseNode): + def __register_in_indexes_by_kind_map(self, name: str, data: BaseNode) -> None: node_kind = data.kind() if node_kind not in self._indexes_by_kind_map: self._indexes_by_kind_map[node_kind] = [] self._indexes_by_kind_map[node_kind].append(name) - def __register_in_indexes_by_memory_type_map(self, name: str, data: BaseNode): + def __register_in_indexes_by_memory_type_map( + self, name: str, data: BaseNode + ) -> None: memory_type = data.memory_type() if memory_type not in self._indexes_by_memory_type_map: @@ -48,19 +50,21 @@ def __register_in_indexes_by_memory_type_map(self, name: str, data: BaseNode): self._indexes_by_memory_type_map[memory_type].append(name) - def __unregister_from_indexes_by_name_map(self, name: str, data: BaseNode): + def __unregister_from_indexes_by_name_map(self, name: str, data: BaseNode) -> None: node_name = data.__class__.__name__.lower() self._indexes_by_name_map[node_name].remove(name) - def __unregister_from_indexes_by_kind_map(self, name: str, data: BaseNode): + def __unregister_from_indexes_by_kind_map(self, name: str, data: BaseNode) -> None: node_kind = data.kind() self._indexes_by_kind_map[node_kind].remove(name) - def __unregister_from_indexes_by_memory_type_map(self, name: str, data: BaseNode): + def __unregister_from_indexes_by_memory_type_map( + self, name: str, data: BaseNode + ) -> None: memory_type = data.memory_type() self._indexes_by_memory_type_map[memory_type].remove(name) - def register(self, name: str, data: BaseNode, overwrite: bool = False): + def register(self, name: str, data: BaseNode, overwrite: bool = False) -> None: """Register an entry.""" if not isinstance(data, BaseNode): raise ValueError("data must be a BaseNode instance.") @@ -71,7 +75,7 @@ def register(self, name: str, data: BaseNode, overwrite: bool = False): self.__register_in_indexes_by_kind_map(name, data) self.__register_in_indexes_by_memory_type_map(name, data) - def unregister(self, name: str): + def unregister(self, name: str) -> None: """Unregister an entry.""" if not self._case_sensitive: name = name.lower() @@ -98,7 +102,7 @@ def get_by_memory_type(self, memory_type: str) -> typing.List[BaseNode]: for id_ in self.indexes_by_memory_type_map.get(memory_type, []) ] - def _filter_connectors(self, connectors, connector_filter): + def _filter_connectors(self, connectors, connector_filter) -> typing.List[str]: result = [] for connector_name, value in connectors.items(): @@ -109,14 +113,28 @@ def _filter_connectors(self, connectors, connector_filter): result.append(connection["node"]) return result - def get_node_input_connections(self, node_id: str, connector_filter=None): + def get_node_input_connections( + self, node_id: str, connector_filter=None + ) -> typing.List[str]: node_dict = self.get(node_id).model_dump(by_alias=True) connectors = node_dict.get("inputs", {}) return self._filter_connectors(connectors, connector_filter) - def get_node_output_connections(self, node_id: str, connector_filter=None): + def get_node_output_connections( + self, node_id: str, connector_filter=None + ) -> typing.List[str]: node_dict = self.get(node_id).model_dump(by_alias=True) connectors = node_dict.get("outputs", {}) return self._filter_connectors(connectors, connector_filter) + + def calculate_edges(self) -> typing.List[typing.Tuple[str, str]]: + edges = [] + + for node_id, node in self.data.items(): + for _, output_connection in node.outputs: + for c in output_connection.connections: + edges.append((node_id, c.node)) + + return edges diff --git a/retrack/utils/graph.py b/retrack/utils/graph.py new file mode 100644 index 0000000..6c2c7c6 --- /dev/null +++ b/retrack/utils/graph.py @@ -0,0 +1,121 @@ +import json + +from unidecode import unidecode +import hashlib +from retrack.utils.registry import Registry +from retrack.utils.component_registry import ComponentRegistry + + +def validate_version( + graph_data: dict, raise_if_null_version: bool, validate_version: bool +) -> str: + version = graph_data.get("version", None) + + graph_json_content = ( + json.dumps(graph_data["nodes"], ensure_ascii=False) + .replace(": ", ":") + .replace("\\", "") + .replace('"', "") + .replace(", ", ",") + ) + graph_json_content = unidecode(graph_json_content, errors="strict") + calculated_hash = hashlib.sha256(graph_json_content.encode()).hexdigest()[:10] + + if version is None: + if raise_if_null_version: + raise ValueError("Missing version") + + return f"{calculated_hash}.dynamic" + + file_version_hash = version.split(".")[0] + + if file_version_hash != calculated_hash and validate_version: + raise ValueError( + f"Invalid version. Graph data has changed and the hash is different: {calculated_hash} != {file_version_hash}" + ) + + return version + + +def validate_data(data: dict) -> dict: + if not isinstance(data, dict): + raise TypeError("Data must be a dictionary. Instead got: " + str(type(data))) + if "nodes" not in data: + raise ValueError("No nodes found in data") + if not isinstance(data["nodes"], dict): + raise TypeError( + "Nodes must be a dictionary. Instead got: " + str(type(data["nodes"])) + ) + return data + + +def validate_with_validators( + graph_data: dict, edges: dict, validator_registry: Registry +): + for validator_name, validator in validator_registry.data.items(): + if not validator.validate(graph_data=graph_data, edges=edges): + raise ValueError(f"Invalid graph data: {validator_name}") + + +def check_node_name(node_name: str, node_id: str): + if node_name is None: + raise ValueError(f"Node {node_id} has no name") + if not isinstance(node_name, str): + raise TypeError(f"Node {node_id} name must be a string") + + +def create_component_registry( + graph_data: dict, nodes_registry: Registry, dynamic_nodes_registry: Registry +) -> ComponentRegistry: + components_registry = ComponentRegistry() + graph_data = validate_data(graph_data) + for node_id, node_metadata in graph_data["nodes"].items(): + if node_id in components_registry: + raise ValueError(f"Duplicate node id: {node_id}") + + node_name = node_metadata.get("name", None) + check_node_name(node_name, node_id) + + node_name = node_name.lower() + + node_factory = dynamic_nodes_registry.get(node_name) + + if node_factory is not None: + validation_model = node_factory(**node_metadata) + else: + validation_model = nodes_registry.get(node_name) + + if validation_model is None: + raise ValueError(f"Unknown node name: {node_name}") + + components_registry.register(node_id, validation_model(**node_metadata)) + + return components_registry + + +def walk(actual_id: str, skiped_ids: list, components_registry: ComponentRegistry): + skiped_ids.append(actual_id) + + output_ids = components_registry.get_node_output_connections(actual_id) + + for next_id in output_ids: + if next_id not in skiped_ids: + next_node_input_ids = components_registry.get_node_input_connections( + next_id + ) + run_next = True + for next_node_input_id in next_node_input_ids: + if next_node_input_id not in skiped_ids: + run_next = False + break + + if run_next: + walk(next_id, skiped_ids, components_registry) + + return skiped_ids + + +def get_execution_order(components_registry: ComponentRegistry): + start_nodes = components_registry.get_by_name("start") + + return walk(start_nodes[0].id, [], components_registry) From ceadfe75b802b0dfbb69d6511c25cf1b534f088f Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 4 Dec 2023 16:55:48 -0300 Subject: [PATCH 07/17] Remove import dependencies from Parser --- retrack/engine/parser.py | 6 ++--- retrack/engine/runner.py | 19 +++++++++++--- retrack/nodes/base.py | 3 +++ retrack/utils/graph.py | 11 +++++++- tests/test_engine/test_runner.py | 17 +++++++++++-- tests/test_parser.py | 43 ++++++++++++++++++++++++++------ 6 files changed, 83 insertions(+), 16 deletions(-) diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py index a7a2edb..3488e35 100644 --- a/retrack/engine/parser.py +++ b/retrack/engine/parser.py @@ -1,7 +1,7 @@ import typing -from retrack import nodes, validators +from retrack import validators from retrack.utils.registry import Registry from retrack.utils.component_registry import ComponentRegistry from retrack.nodes.base import NodeKind @@ -12,8 +12,8 @@ class Parser: def __init__( self, graph_data: dict, - nodes_registry: Registry = nodes.registry(), - dynamic_nodes_registry: Registry = nodes.dynamic_nodes_registry(), + nodes_registry: Registry, + dynamic_nodes_registry: Registry, validator_registry: Registry = validators.registry(), raise_if_null_version: bool = False, validate_version: bool = True, diff --git a/retrack/engine/runner.py b/retrack/engine/runner.py index 7b545bc..81c5693 100644 --- a/retrack/engine/runner.py +++ b/retrack/engine/runner.py @@ -9,7 +9,8 @@ from retrack.engine.parser import Parser from retrack.engine.request_manager import RequestManager from retrack.nodes.base import NodeKind, NodeMemoryType -from retrack.utils import constants +from retrack.utils import constants, registry +from retrack import nodes class Runner: @@ -27,7 +28,14 @@ def __init__(self, parser: Parser, name: str = None): ) @classmethod - def from_json(cls, data: typing.Union[str, dict], name: str = None, **kwargs): + def from_json( + cls, + data: typing.Union[str, dict], + name: str = None, + nodes_registry: registry.Registry = nodes.registry(), + dynamic_nodes_registry: registry.Registry = nodes.dynamic_nodes_registry(), + **kwargs, + ): if isinstance(data, str) and data.endswith(".json"): if name is None: name = data @@ -35,7 +43,12 @@ def from_json(cls, data: typing.Union[str, dict], name: str = None, **kwargs): elif not isinstance(data, dict): raise ValueError("data must be a dict or a json file path") - parser = Parser(data, **kwargs) + parser = Parser( + data, + nodes_registry=nodes_registry, + dynamic_nodes_registry=dynamic_nodes_registry, + **kwargs, + ) return cls(parser, name=name) @property diff --git a/retrack/nodes/base.py b/retrack/nodes/base.py index 63560e4..8e419e5 100644 --- a/retrack/nodes/base.py +++ b/retrack/nodes/base.py @@ -84,3 +84,6 @@ def kind(self) -> NodeKind: def memory_type(self) -> NodeMemoryType: return NodeMemoryType.STATE + + def generate_input_nodes(self) -> typing.List["BaseNode"]: + return [] diff --git a/retrack/utils/graph.py b/retrack/utils/graph.py index 6c2c7c6..34dccf6 100644 --- a/retrack/utils/graph.py +++ b/retrack/utils/graph.py @@ -81,13 +81,22 @@ def create_component_registry( node_factory = dynamic_nodes_registry.get(node_name) if node_factory is not None: - validation_model = node_factory(**node_metadata) + validation_model = node_factory( + **node_metadata, + nodes_registry=nodes_registry, + dynamic_nodes_registry=dynamic_nodes_registry, + ) else: validation_model = nodes_registry.get(node_name) if validation_model is None: raise ValueError(f"Unknown node name: {node_name}") + component = validation_model(**node_metadata) + + for input_node in component.generate_input_nodes(): + components_registry.register(input_node.id, input_node) + components_registry.register(node_id, validation_model(**node_metadata)) return components_registry diff --git a/tests/test_engine/test_runner.py b/tests/test_engine/test_runner.py index f9f7842..2d2e6a3 100644 --- a/tests/test_engine/test_runner.py +++ b/tests/test_engine/test_runner.py @@ -4,6 +4,7 @@ import pytest from retrack import Parser, Runner +from retrack import nodes @pytest.mark.parametrize( @@ -29,7 +30,13 @@ def test_flows_with_single_element(filename, in_values, expected_out_values): with open(f"tests/resources/{filename}.json", "r") as f: rule = json.load(f) - runner = Runner(Parser(rule)) + runner = Runner( + Parser( + rule, + nodes_registry=nodes.registry(), + dynamic_nodes_registry=nodes.dynamic_nodes_registry(), + ) + ) out_values = runner.execute(pd.DataFrame([in_values])) assert isinstance(out_values, pd.DataFrame) @@ -107,7 +114,13 @@ def test_flows(filename, in_values, expected_out_values): with open(f"tests/resources/{filename}.json", "r") as f: rule = json.load(f) - runner = Runner(Parser(rule)) + runner = Runner( + Parser( + rule, + nodes_registry=nodes.registry(), + dynamic_nodes_registry=nodes.dynamic_nodes_registry(), + ) + ) out_values = runner.execute(pd.DataFrame(in_values)) assert isinstance(out_values, pd.DataFrame) diff --git a/tests/test_parser.py b/tests/test_parser.py index 15976e4..ed42556 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -3,6 +3,7 @@ import pytest from retrack.engine.parser import Parser +from retrack import nodes @pytest.mark.parametrize( @@ -26,35 +27,63 @@ def test_parser_extract(data_filename, expected_tokens): with open(data_filename) as f: input_data = json.load(f) - parser = Parser(input_data) + parser = Parser( + input_data, + nodes_registry=nodes.registry(), + dynamic_nodes_registry=nodes.dynamic_nodes_registry(), + ) assert parser.components_registry.indexes_by_name_map == expected_tokens def test_parser_with_unknown_node(): with pytest.raises(ValueError): - Parser({"nodes": {"1": {"name": "Unknown"}}}) + Parser( + {"nodes": {"1": {"name": "Unknown"}}}, + nodes_registry=nodes.registry(), + dynamic_nodes_registry=nodes.dynamic_nodes_registry(), + ) def test_parser_invalid_input_data(): with pytest.raises(TypeError): - Parser("invalid data") + Parser( + "invalid data", + nodes_registry=nodes.registry(), + dynamic_nodes_registry=nodes.dynamic_nodes_registry(), + ) def test_parser_no_nodes(): with pytest.raises(ValueError): - Parser({}) + Parser( + {}, + nodes_registry=nodes.registry(), + dynamic_nodes_registry=nodes.dynamic_nodes_registry(), + ) def test_parser_invalid_nodes(): with pytest.raises(TypeError): - Parser({"nodes": "invalid nodes"}) + Parser( + {"nodes": "invalid nodes"}, + nodes_registry=nodes.registry(), + dynamic_nodes_registry=nodes.dynamic_nodes_registry(), + ) def test_parser_node_no_name(): with pytest.raises(ValueError): - Parser({"nodes": {"1": {}}}) + Parser( + {"nodes": {"1": {}}}, + nodes_registry=nodes.registry(), + dynamic_nodes_registry=nodes.dynamic_nodes_registry(), + ) def test_parser_node_invalid_name(): with pytest.raises(TypeError): - Parser({"nodes": {"1": {"name": 1}}}) + Parser( + {"nodes": {"1": {"name": 1}}}, + nodes_registry=nodes.registry(), + dynamic_nodes_registry=nodes.dynamic_nodes_registry(), + ) From e60c095d3d1fbef61e6f309ebee787e761e64c4d Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 4 Dec 2023 19:18:38 -0300 Subject: [PATCH 08/17] Remove parser class --- retrack/__init__.py | 8 +- retrack/engine/constructor.py | 31 +++ retrack/engine/parser.py | 61 ------ retrack/engine/{runner.py => rule.py} | 179 +++++++++++------- retrack/nodes/dynamic/flow.py | 27 ++- retrack/utils/graph.py | 38 ---- .../{test_runner.py => test_executor.py} | 45 ++--- tests/test_parser.py | 89 --------- 8 files changed, 189 insertions(+), 289 deletions(-) create mode 100644 retrack/engine/constructor.py delete mode 100644 retrack/engine/parser.py rename retrack/engine/{runner.py => rule.py} (63%) rename tests/test_engine/{test_runner.py => test_executor.py} (90%) delete mode 100644 tests/test_parser.py diff --git a/retrack/__init__.py b/retrack/__init__.py index ede84ed..b3cc423 100644 --- a/retrack/__init__.py +++ b/retrack/__init__.py @@ -1,11 +1,11 @@ -from retrack.engine.parser import Parser -from retrack.engine.runner import Runner +from retrack.engine.rule import Rule +from retrack.engine.constructor import from_json from retrack.nodes import registry as nodes_registry from retrack.nodes.base import BaseNode, InputConnectionModel, OutputConnectionModel __all__ = [ - "Parser", - "Runner", + "Rule", + "from_json", "BaseNode", "InputConnectionModel", "OutputConnectionModel", diff --git a/retrack/engine/constructor.py b/retrack/engine/constructor.py new file mode 100644 index 0000000..b9c1d6b --- /dev/null +++ b/retrack/engine/constructor.py @@ -0,0 +1,31 @@ +import typing + +import json + +from retrack.engine.rule import Rule, RuleExecutor +from retrack.utils import registry +from retrack import nodes + + +def from_json( + data: typing.Union[str, dict], + name: str = None, + nodes_registry: registry.Registry = nodes.registry(), + dynamic_nodes_registry: registry.Registry = nodes.dynamic_nodes_registry(), + **kwargs, +) -> RuleExecutor: + if isinstance(data, str) and data.endswith(".json"): + if name is None: + name = data + graph_data = json.loads(open(data).read()) + elif not isinstance(data, dict): + raise ValueError("data must be a dict or a json file path") + + rule = Rule.create( + graph_data=graph_data, + name=name, + nodes_registry=nodes_registry, + dynamic_nodes_registry=dynamic_nodes_registry, + **kwargs, + ) + return rule.executor diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py deleted file mode 100644 index 3488e35..0000000 --- a/retrack/engine/parser.py +++ /dev/null @@ -1,61 +0,0 @@ -import typing - - -from retrack import validators -from retrack.utils.registry import Registry -from retrack.utils.component_registry import ComponentRegistry -from retrack.nodes.base import NodeKind -from retrack.utils import graph - - -class Parser: - def __init__( - self, - graph_data: dict, - nodes_registry: Registry, - dynamic_nodes_registry: Registry, - validator_registry: Registry = validators.registry(), - raise_if_null_version: bool = False, - validate_version: bool = True, - ): - self.__components_registry = graph.create_component_registry( - graph_data, nodes_registry, dynamic_nodes_registry - ) - self._version = graph.validate_version( - graph_data, raise_if_null_version, validate_version - ) - self.__graph_data = graph_data - - graph.validate_with_validators( - self.graph_data, - self.components_registry.calculate_edges(), - validator_registry, - ) - - self._execution_order = graph.get_execution_order(self.components_registry) - - self._set_input_nodes_from_connectors() - - @property - def graph_data(self) -> dict: - return self.__graph_data - - @property - def version(self) -> str: - return self._version - - @property - def components_registry(self) -> ComponentRegistry: - return self.__components_registry - - @property - def execution_order(self) -> typing.List[str]: - return self._execution_order - - def _set_input_nodes_from_connectors(self): - connector_nodes = self.components_registry.get_by_kind(NodeKind.CONNECTOR) - - for connector_node in connector_nodes: - input_nodes = connector_node.generate_input_nodes() - for input_node in input_nodes: - self.components_registry.register(input_node.id, input_node) diff --git a/retrack/engine/runner.py b/retrack/engine/rule.py similarity index 63% rename from retrack/engine/runner.py rename to retrack/engine/rule.py index 81c5693..7b33059 100644 --- a/retrack/engine/runner.py +++ b/retrack/engine/rule.py @@ -1,63 +1,34 @@ +import pydantic +from retrack.utils.component_registry import ComponentRegistry +from retrack.utils.registry import Registry + +from retrack.utils import graph +from retrack import validators import typing -import json import numpy as np import pandas as pd -import pydantic -from retrack.engine.parser import Parser from retrack.engine.request_manager import RequestManager from retrack.nodes.base import NodeKind, NodeMemoryType -from retrack.utils import constants, registry -from retrack import nodes +from retrack.utils import constants -class Runner: - def __init__(self, parser: Parser, name: str = None): - self._parser = parser - self._name = name - self._internal_runners = {} +class RuleExecutor: + def __init__(self, rule: "Rule"): + self._rule = rule self._validated_payload = None self.reset() self._set_constants() self._set_input_columns() - self._set_internal_runners() self._request_manager = RequestManager( - self._parser.components_registry.get_by_kind(NodeKind.INPUT) + self._rule.components_registry.get_by_kind(NodeKind.INPUT) ) - @classmethod - def from_json( - cls, - data: typing.Union[str, dict], - name: str = None, - nodes_registry: registry.Registry = nodes.registry(), - dynamic_nodes_registry: registry.Registry = nodes.dynamic_nodes_registry(), - **kwargs, - ): - if isinstance(data, str) and data.endswith(".json"): - if name is None: - name = data - data = json.loads(open(data).read()) - elif not isinstance(data, dict): - raise ValueError("data must be a dict or a json file path") - - parser = Parser( - data, - nodes_registry=nodes_registry, - dynamic_nodes_registry=dynamic_nodes_registry, - **kwargs, - ) - return cls(parser, name=name) - - @property - def parser(self) -> Parser: - return self._parser - @property - def name(self) -> str: - return self._name + def rule(self) -> "Rule": + return self._rule @property def request_manager(self) -> RequestManager: @@ -80,7 +51,7 @@ def constants(self) -> dict: return self._constants def _set_constants(self): - constant_nodes = self.parser.components_registry.get_by_memory_type( + constant_nodes = self.rule.components_registry.get_by_memory_type( NodeMemoryType.CONSTANT ) self._constants = {} @@ -88,26 +59,12 @@ def _set_constants(self): for output_connector_name, _ in node.outputs: self._constants[f"{node.id}@{output_connector_name}"] = node.data.value - def _set_internal_runners(self): - for node_id in self.parser.components_registry.indexes_by_name_map.get( - constants.FLOW_NODE_NAME, [] - ): - try: - node_data = self.parser.components_registry.get(node_id).data - self._internal_runners[node_id] = Runner.from_json( - node_data.parsed_value(), name=node_data.name - ) - except Exception as e: - raise Exception( - f"Error setting internal runner for node {node_id}" - ) from e - @property def input_columns(self) -> dict: return self._input_columns def _set_input_columns(self): - input_nodes = self._parser.components_registry.get_by_kind(NodeKind.INPUT) + input_nodes = self._rule.components_registry.get_by_kind(NodeKind.INPUT) self._input_columns = { f"{node.id}@{constants.INPUT_OUTPUT_VALUE_CONNECTOR_NAME}": node.data.name for node in input_nodes @@ -122,7 +79,7 @@ def __set_output_connection_filters( ): if filter is not None: output_connections = ( - self.parser.components_registry.get_node_output_connections( + self.rule.components_registry.get_node_output_connections( node_id, connector_filter=connector_filter ) ) @@ -165,11 +122,6 @@ def __get_input_params( f"{connection['node']}@{connection['output']}", current_node_filter ) - if node_id in self._internal_runners: - input_params["runner"] = self._internal_runners[node_id] - for column_name, column_value in self._validated_payload.items(): - input_params[f"payload_{column_name}"] = column_value - return input_params def __set_state_data( @@ -194,7 +146,7 @@ def __run_node(self, node_id: str): # if there is a filter, we need to set the children nodes to receive filtered data self.__set_output_connection_filters(node_id, current_node_filter) - node = self.parser.components_registry.get(node_id) + node = self.rule.components_registry.get(node_id) if node.memory_type == NodeMemoryType.CONSTANT: return @@ -237,12 +189,12 @@ def execute( self.reset() self._states = self._create_initial_state_from_payload(payload_df) - for node_id in self.parser.execution_order: + for node_id in self.rule.execution_order: try: self.__run_node(node_id) except Exception as e: raise Exception( - f"Error running node {node_id} in {self.name} with version {self.parser.version}" + f"Error running node {node_id} in {self.rule.name} with version {self.rule.version}" ) from e if self.states[constants.OUTPUT_REFERENCE_COLUMN].isna().sum() == 0: @@ -257,3 +209,96 @@ def execute( constants.OUTPUT_MESSAGE_REFERENCE_COLUMN, ] ] + + +class Rule(pydantic.BaseModel): + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + + name: typing.Optional[str] = None + version: str + components_registry: ComponentRegistry + execution_order: typing.List[str] + _executor: RuleExecutor = None + + @property + def executor(self) -> RuleExecutor: + if self._executor is None: + self._executor = RuleExecutor(self) + return self._executor + + @classmethod + def create( + cls, + graph_data: dict, + nodes_registry: Registry, + dynamic_nodes_registry: Registry, + validator_registry: Registry = validators.registry(), + raise_if_null_version: bool = False, + validate_version: bool = True, + name: str = None, + ): + components_registry = create_component_registry( + graph_data, nodes_registry, dynamic_nodes_registry, validator_registry + ) + version = graph.validate_version( + graph_data, raise_if_null_version, validate_version + ) + graph_data = graph_data + + graph.validate_with_validators( + graph_data, + components_registry.calculate_edges(), + validator_registry, + ) + + execution_order = graph.get_execution_order(components_registry) + + return cls( + version=version, + components_registry=components_registry, + execution_order=execution_order, + name=name, + ) + + +def create_component_registry( + graph_data: dict, + nodes_registry: Registry, + dynamic_nodes_registry: Registry, + validator_registry: Registry, +) -> ComponentRegistry: + components_registry = ComponentRegistry() + graph_data = graph.validate_data(graph_data) + for node_id, node_metadata in graph_data["nodes"].items(): + if node_id in components_registry: + raise ValueError(f"Duplicate node id: {node_id}") + + node_name = node_metadata.get("name", None) + graph.check_node_name(node_name, node_id) + + node_name = node_name.lower() + + node_factory = dynamic_nodes_registry.get(node_name) + + if node_factory is not None: + validation_model = node_factory( + **node_metadata, + nodes_registry=nodes_registry, + dynamic_nodes_registry=dynamic_nodes_registry, + validator_registry=validator_registry, + rule_class=Rule, + ) + else: + validation_model = nodes_registry.get(node_name) + + if validation_model is None: + raise ValueError(f"Unknown node name: {node_name}") + + component = validation_model(**node_metadata) + + for input_node in component.generate_input_nodes(): + components_registry.register(input_node.id, input_node) + + components_registry.register(node_id, validation_model(**node_metadata)) + + return components_registry diff --git a/retrack/nodes/dynamic/flow.py b/retrack/nodes/dynamic/flow.py index de57c5b..e3e453b 100644 --- a/retrack/nodes/dynamic/flow.py +++ b/retrack/nodes/dynamic/flow.py @@ -7,6 +7,7 @@ from retrack.nodes.base import InputConnectionModel, OutputConnectionModel from retrack.nodes.dynamic.base import BaseDynamicIOModel, BaseDynamicNode +from retrack.utils.registry import Registry class FlowV0MetadataModel(pydantic.BaseModel): @@ -23,8 +24,22 @@ class FlowV0OutputsModel(pydantic.BaseModel): def flow_factory( - inputs: typing.Dict[str, typing.Any], **kwargs + inputs: typing.Dict[str, typing.Any], + nodes_registry: Registry, + dynamic_nodes_registry: Registry, + validator_registry: Registry, + data: dict, + rule_class, + **factory_kwargs, ) -> typing.Type[BaseDynamicNode]: + graph_data = json.loads(data["value"]) + rule_instance = rule_class.create( + graph_data=graph_data, + nodes_registry=nodes_registry, + dynamic_nodes_registry=dynamic_nodes_registry, + validator_registry=validator_registry, + name=data["name"], + ) input_fields = {} for name in inputs.keys(): @@ -42,10 +57,6 @@ def flow_factory( class FlowV0(BaseFlowV0Model): def run(self, **kwargs) -> typing.Dict[str, typing.Any]: - runner = kwargs.get("runner", None) - if runner is None: - raise ValueError("Missing runner") - inputs_in_kwargs = {} for name, value in kwargs.items(): @@ -54,8 +65,12 @@ def run(self, **kwargs) -> typing.Dict[str, typing.Any]: elif name.startswith("payload_"): inputs_in_kwargs[name[len("payload_") :]] = value - response = runner.execute(pd.DataFrame(inputs_in_kwargs)) + response = rule_instance.executor.execute(pd.DataFrame(inputs_in_kwargs)) return {"output_value": response["output"].values} + def generate_input_nodes(self): + # TODO: check inputs that do not have input nodes + return [] + return FlowV0 diff --git a/retrack/utils/graph.py b/retrack/utils/graph.py index 34dccf6..43985e6 100644 --- a/retrack/utils/graph.py +++ b/retrack/utils/graph.py @@ -64,44 +64,6 @@ def check_node_name(node_name: str, node_id: str): raise TypeError(f"Node {node_id} name must be a string") -def create_component_registry( - graph_data: dict, nodes_registry: Registry, dynamic_nodes_registry: Registry -) -> ComponentRegistry: - components_registry = ComponentRegistry() - graph_data = validate_data(graph_data) - for node_id, node_metadata in graph_data["nodes"].items(): - if node_id in components_registry: - raise ValueError(f"Duplicate node id: {node_id}") - - node_name = node_metadata.get("name", None) - check_node_name(node_name, node_id) - - node_name = node_name.lower() - - node_factory = dynamic_nodes_registry.get(node_name) - - if node_factory is not None: - validation_model = node_factory( - **node_metadata, - nodes_registry=nodes_registry, - dynamic_nodes_registry=dynamic_nodes_registry, - ) - else: - validation_model = nodes_registry.get(node_name) - - if validation_model is None: - raise ValueError(f"Unknown node name: {node_name}") - - component = validation_model(**node_metadata) - - for input_node in component.generate_input_nodes(): - components_registry.register(input_node.id, input_node) - - components_registry.register(node_id, validation_model(**node_metadata)) - - return components_registry - - def walk(actual_id: str, skiped_ids: list, components_registry: ComponentRegistry): skiped_ids.append(actual_id) diff --git a/tests/test_engine/test_runner.py b/tests/test_engine/test_executor.py similarity index 90% rename from tests/test_engine/test_runner.py rename to tests/test_engine/test_executor.py index 2d2e6a3..7dda2e9 100644 --- a/tests/test_engine/test_runner.py +++ b/tests/test_engine/test_executor.py @@ -3,8 +3,7 @@ import pandas as pd import pytest -from retrack import Parser, Runner -from retrack import nodes +from retrack import Rule, nodes, from_json @pytest.mark.parametrize( @@ -28,16 +27,15 @@ ) def test_flows_with_single_element(filename, in_values, expected_out_values): with open(f"tests/resources/{filename}.json", "r") as f: - rule = json.load(f) + graph_data = json.load(f) - runner = Runner( - Parser( - rule, - nodes_registry=nodes.registry(), - dynamic_nodes_registry=nodes.dynamic_nodes_registry(), - ) - ) - out_values = runner.execute(pd.DataFrame([in_values])) + executor = Rule.create( + graph_data, + nodes_registry=nodes.registry(), + dynamic_nodes_registry=nodes.dynamic_nodes_registry(), + ).executor + + out_values = executor.execute(pd.DataFrame([in_values])) assert isinstance(out_values, pd.DataFrame) assert out_values.to_dict(orient="records") == expected_out_values @@ -112,16 +110,15 @@ def test_flows_with_single_element(filename, in_values, expected_out_values): ) def test_flows(filename, in_values, expected_out_values): with open(f"tests/resources/{filename}.json", "r") as f: - rule = json.load(f) - - runner = Runner( - Parser( - rule, - nodes_registry=nodes.registry(), - dynamic_nodes_registry=nodes.dynamic_nodes_registry(), - ) - ) - out_values = runner.execute(pd.DataFrame(in_values)) + graph_data = json.load(f) + + executor = Rule.create( + graph_data, + nodes_registry=nodes.registry(), + dynamic_nodes_registry=nodes.dynamic_nodes_registry(), + ).executor + + out_values = executor.execute(pd.DataFrame(in_values)) assert isinstance(out_values, pd.DataFrame) assert out_values.to_dict(orient="records") == expected_out_values @@ -207,7 +204,7 @@ def test_flows(filename, in_values, expected_out_values): ], ) def test_create_from_json(filename, in_values, expected_out_values): - runner = Runner.from_json(f"tests/resources/{filename}.json") + runner = from_json(f"tests/resources/{filename}.json") out_values = runner.execute(pd.DataFrame(in_values)) assert isinstance(out_values, pd.DataFrame) @@ -216,11 +213,11 @@ def test_create_from_json(filename, in_values, expected_out_values): def test_create_from_json_with_invalid_type(): with pytest.raises(ValueError): - Runner.from_json(1) + from_json(1) def test_csv_table_with_if(): - runner = Runner.from_json("tests/resources/csv-table-with-if.json") + runner = from_json("tests/resources/csv-table-with-if.json") in_values = [ {"in_a": 0, "in_b": 0, "in_d": 0, "in_e": 0}, diff --git a/tests/test_parser.py b/tests/test_parser.py deleted file mode 100644 index ed42556..0000000 --- a/tests/test_parser.py +++ /dev/null @@ -1,89 +0,0 @@ -import json - -import pytest - -from retrack.engine.parser import Parser -from retrack import nodes - - -@pytest.mark.parametrize( - "data_filename,expected_tokens", - [ - ( - "tests/resources/age-negative.json", - { - "start": ["0"], - "input": ["2", "13"], - "constant": ["3", "14"], - "check": ["4", "15"], - "if": ["6", "16"], - "bool": ["9", "17", "18"], - "output": ["10", "19", "20"], - }, - ) - ], -) -def test_parser_extract(data_filename, expected_tokens): - with open(data_filename) as f: - input_data = json.load(f) - - parser = Parser( - input_data, - nodes_registry=nodes.registry(), - dynamic_nodes_registry=nodes.dynamic_nodes_registry(), - ) - assert parser.components_registry.indexes_by_name_map == expected_tokens - - -def test_parser_with_unknown_node(): - with pytest.raises(ValueError): - Parser( - {"nodes": {"1": {"name": "Unknown"}}}, - nodes_registry=nodes.registry(), - dynamic_nodes_registry=nodes.dynamic_nodes_registry(), - ) - - -def test_parser_invalid_input_data(): - with pytest.raises(TypeError): - Parser( - "invalid data", - nodes_registry=nodes.registry(), - dynamic_nodes_registry=nodes.dynamic_nodes_registry(), - ) - - -def test_parser_no_nodes(): - with pytest.raises(ValueError): - Parser( - {}, - nodes_registry=nodes.registry(), - dynamic_nodes_registry=nodes.dynamic_nodes_registry(), - ) - - -def test_parser_invalid_nodes(): - with pytest.raises(TypeError): - Parser( - {"nodes": "invalid nodes"}, - nodes_registry=nodes.registry(), - dynamic_nodes_registry=nodes.dynamic_nodes_registry(), - ) - - -def test_parser_node_no_name(): - with pytest.raises(ValueError): - Parser( - {"nodes": {"1": {}}}, - nodes_registry=nodes.registry(), - dynamic_nodes_registry=nodes.dynamic_nodes_registry(), - ) - - -def test_parser_node_invalid_name(): - with pytest.raises(TypeError): - Parser( - {"nodes": {"1": {"name": 1}}}, - nodes_registry=nodes.registry(), - dynamic_nodes_registry=nodes.dynamic_nodes_registry(), - ) From b0267aad07200ff374068ef667a82119c57d7b3f Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 4 Dec 2023 20:37:04 -0300 Subject: [PATCH 09/17] Refactor RuleExecutor class and add Execution --- retrack/engine/rule.py | 123 ++++++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 56 deletions(-) diff --git a/retrack/engine/rule.py b/retrack/engine/rule.py index 7b33059..8fa91d0 100644 --- a/retrack/engine/rule.py +++ b/retrack/engine/rule.py @@ -15,10 +15,44 @@ from retrack.utils import constants +class Execution: + def __init__(self, states: pd.DataFrame): + self.states = states + + def set_state_data( + self, column: str, value: typing.Any, filter_by: typing.Any = None + ): + if filter_by is None: + self.states[column] = value + else: + self.states.loc[filter_by, column] = value + + def get_state_data( + self, column: str, constants: dict, filter_by: typing.Any = None + ): + if column in constants: + return constants[column] + + if filter_by is None: + return self.states[column] + + return self.states.loc[filter_by, column] + + @classmethod + def from_payload(cls, validated_payload: pd.DataFrame, input_columns: dict): + state_df = pd.DataFrame([]) + for node_id, input_name in input_columns.items(): + state_df[node_id] = validated_payload[input_name] + + state_df[constants.OUTPUT_REFERENCE_COLUMN] = np.nan + state_df[constants.OUTPUT_MESSAGE_REFERENCE_COLUMN] = np.nan + + return cls(state_df) + + class RuleExecutor: def __init__(self, rule: "Rule"): self._rule = rule - self._validated_payload = None self.reset() self._set_constants() self._set_input_columns() @@ -38,10 +72,6 @@ def request_manager(self) -> RequestManager: def request_model(self) -> pydantic.BaseModel: return self._request_manager.model - @property - def states(self) -> pd.DataFrame: - return self._states - @property def filters(self) -> dict: return self._filters @@ -71,7 +101,6 @@ def _set_input_columns(self): } def reset(self): - self._states = None self._filters = {} def __set_output_connection_filters( @@ -91,25 +120,11 @@ def __set_output_connection_filters( self._filters[output_connection_id] & filter ) - def _create_initial_state_from_payload( - self, payload_df: pd.DataFrame - ) -> pd.DataFrame: - """Create initial state from payload. This is the first step of the runner.""" - self._validated_payload = self.request_manager.validate( - payload_df.reset_index(drop=True) - ) - - state_df = pd.DataFrame([]) - for node_id, input_name in self.input_columns.items(): - state_df[node_id] = self._validated_payload[input_name] - - state_df[constants.OUTPUT_REFERENCE_COLUMN] = np.nan - state_df[constants.OUTPUT_MESSAGE_REFERENCE_COLUMN] = np.nan - - return state_df - def __get_input_params( - self, node_id: str, node_dict: dict, current_node_filter: pd.Series + self, + node_dict: dict, + current_node_filter: pd.Series, + execution: Execution, ) -> dict: input_params = {} @@ -118,30 +133,15 @@ def __get_input_params( continue for connection in connections["connections"]: - input_params[connector_name] = self.__get_state_data( - f"{connection['node']}@{connection['output']}", current_node_filter + input_params[connector_name] = execution.get_state_data( + f"{connection['node']}@{connection['output']}", + constants=self.constants, + filter_by=current_node_filter, ) return input_params - def __set_state_data( - self, column: str, value: typing.Any, filter_by: typing.Any = None - ): - if filter_by is None: - self._states[column] = value - else: - self._states.loc[filter_by, column] = value - - def __get_state_data(self, column: str, filter_by: typing.Any = None): - if column in self._constants: - return self._constants[column] - - if filter_by is None: - return self._states[column] - - return self._states.loc[filter_by, column] - - def __run_node(self, node_id: str): + def __run_node(self, node_id: str, execution: Execution): current_node_filter = self._filters.get(node_id, None) # if there is a filter, we need to set the children nodes to receive filtered data self.__set_output_connection_filters(node_id, current_node_filter) @@ -152,7 +152,9 @@ def __run_node(self, node_id: str): return input_params = self.__get_input_params( - node_id, node.model_dump(by_alias=True), current_node_filter + node.model_dump(by_alias=True), + current_node_filter, + execution=execution, ) output = node.run(**input_params) @@ -161,14 +163,22 @@ def __run_node(self, node_id: str): output_name == constants.OUTPUT_REFERENCE_COLUMN or output_name == constants.OUTPUT_MESSAGE_REFERENCE_COLUMN ): # Setting output values - self.__set_state_data(output_name, output_value, current_node_filter) + execution.set_state_data(output_name, output_value, current_node_filter) elif output_name.endswith(constants.FILTER_SUFFIX): # Setting filters self.__set_output_connection_filters(node_id, output_value, output_name) else: # Setting node outputs to be used as inputs by other nodes - self.__set_state_data( - f"{node_id}@{output_name}", output_value, current_node_filter + execution.set_state_data( + f"{node_id}@{output_name}", + output_value, + filter_by=current_node_filter, ) + def validate_payload(self, payload_df: pd.DataFrame): + if not isinstance(payload_df, pd.DataFrame): + raise ValueError("payload_df must be a pandas.DataFrame") + + return self.request_manager.validate(payload_df.reset_index(drop=True)) + def execute( self, payload_df: pd.DataFrame, @@ -183,27 +193,28 @@ def execute( Returns: pd.DataFrame: The output of the flow. """ - if not isinstance(payload_df, pd.DataFrame): - raise ValueError("payload_df must be a pandas.DataFrame") - self.reset() - self._states = self._create_initial_state_from_payload(payload_df) + + execution = Execution.from_payload( + validated_payload=self.validate_payload(payload_df), + input_columns=self.input_columns, + ) for node_id in self.rule.execution_order: try: - self.__run_node(node_id) + self.__run_node(node_id, execution=execution) except Exception as e: raise Exception( f"Error running node {node_id} in {self.rule.name} with version {self.rule.version}" ) from e - if self.states[constants.OUTPUT_REFERENCE_COLUMN].isna().sum() == 0: + if execution.states[constants.OUTPUT_REFERENCE_COLUMN].isna().sum() == 0: break if return_all_states: - return self.states + return execution.states - return self.states[ + return execution.states[ [ constants.OUTPUT_REFERENCE_COLUMN, constants.OUTPUT_MESSAGE_REFERENCE_COLUMN, From 66d244d587628b202939d9ced4387f8ee4407aef Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 4 Dec 2023 21:07:47 -0300 Subject: [PATCH 10/17] Remove filters from RuleExecutor --- retrack/engine/rule.py | 61 +++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/retrack/engine/rule.py b/retrack/engine/rule.py index 8fa91d0..619c33a 100644 --- a/retrack/engine/rule.py +++ b/retrack/engine/rule.py @@ -16,8 +16,9 @@ class Execution: - def __init__(self, states: pd.DataFrame): + def __init__(self, states: pd.DataFrame, filters: dict = None): self.states = states + self.filters = filters or {} def set_state_data( self, column: str, value: typing.Any, filter_by: typing.Any = None @@ -38,6 +39,15 @@ def get_state_data( return self.states.loc[filter_by, column] + def update_filters(self, filter_value, output_connections: typing.List[str] = None): + for output_connection_id in output_connections: + if self.filters.get(output_connection_id, None) is None: + self.filters[output_connection_id] = filter_value + else: + self.filters[output_connection_id] = ( + self.filters[output_connection_id] & filter_value + ) + @classmethod def from_payload(cls, validated_payload: pd.DataFrame, input_columns: dict): state_df = pd.DataFrame([]) @@ -53,7 +63,6 @@ def from_payload(cls, validated_payload: pd.DataFrame, input_columns: dict): class RuleExecutor: def __init__(self, rule: "Rule"): self._rule = rule - self.reset() self._set_constants() self._set_input_columns() self._request_manager = RequestManager( @@ -72,10 +81,6 @@ def request_manager(self) -> RequestManager: def request_model(self) -> pydantic.BaseModel: return self._request_manager.model - @property - def filters(self) -> dict: - return self._filters - @property def constants(self) -> dict: return self._constants @@ -100,25 +105,20 @@ def _set_input_columns(self): for node in input_nodes } - def reset(self): - self._filters = {} - def __set_output_connection_filters( - self, node_id: str, filter: typing.Any, connector_filter=None + self, + node_id: str, + filter_value: typing.Any, + execution: Execution, + connector_filter=None, ): - if filter is not None: - output_connections = ( - self.rule.components_registry.get_node_output_connections( - node_id, connector_filter=connector_filter - ) - ) - for output_connection_id in output_connections: - if self._filters.get(output_connection_id, None) is None: - self._filters[output_connection_id] = filter - else: - self._filters[output_connection_id] = ( - self._filters[output_connection_id] & filter - ) + if filter_value is None: + return + + output_connections = self.rule.components_registry.get_node_output_connections( + node_id, connector_filter=connector_filter + ) + execution.update_filters(filter_value, output_connections=output_connections) def __get_input_params( self, @@ -142,9 +142,11 @@ def __get_input_params( return input_params def __run_node(self, node_id: str, execution: Execution): - current_node_filter = self._filters.get(node_id, None) + current_node_filter = execution.filters.get(node_id, None) # if there is a filter, we need to set the children nodes to receive filtered data - self.__set_output_connection_filters(node_id, current_node_filter) + self.__set_output_connection_filters( + node_id, current_node_filter, execution=execution + ) node = self.rule.components_registry.get(node_id) @@ -165,7 +167,12 @@ def __run_node(self, node_id: str, execution: Execution): ): # Setting output values execution.set_state_data(output_name, output_value, current_node_filter) elif output_name.endswith(constants.FILTER_SUFFIX): # Setting filters - self.__set_output_connection_filters(node_id, output_value, output_name) + self.__set_output_connection_filters( + node_id, + output_value, + execution=execution, + connector_filter=output_name, + ) else: # Setting node outputs to be used as inputs by other nodes execution.set_state_data( f"{node_id}@{output_name}", @@ -193,8 +200,6 @@ def execute( Returns: pd.DataFrame: The output of the flow. """ - self.reset() - execution = Execution.from_payload( validated_payload=self.validate_payload(payload_df), input_columns=self.input_columns, From 181a7b8c6370d972990a088befd819422704b077 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 4 Dec 2023 21:33:25 -0300 Subject: [PATCH 11/17] Add base.py and executor.py to retrack/engine --- retrack/engine/base.py | 55 +++++++ retrack/engine/executor.py | 187 ++++++++++++++++++++++ retrack/engine/rule.py | 309 ++++++------------------------------- 3 files changed, 291 insertions(+), 260 deletions(-) create mode 100644 retrack/engine/base.py create mode 100644 retrack/engine/executor.py diff --git a/retrack/engine/base.py b/retrack/engine/base.py new file mode 100644 index 0000000..0345c41 --- /dev/null +++ b/retrack/engine/base.py @@ -0,0 +1,55 @@ +import pydantic +import typing +import pandas as pd +import numpy as np +from retrack.utils import constants + + +class RuleMetadata(pydantic.BaseModel): + name: typing.Optional[str] = None + version: str + + +class Execution: + def __init__(self, states: pd.DataFrame, filters: dict = None): + self.states = states + self.filters = filters or {} + + def set_state_data( + self, column: str, value: typing.Any, filter_by: typing.Any = None + ): + if filter_by is None: + self.states[column] = value + else: + self.states.loc[filter_by, column] = value + + def get_state_data( + self, column: str, constants: dict, filter_by: typing.Any = None + ): + if column in constants: + return constants[column] + + if filter_by is None: + return self.states[column] + + return self.states.loc[filter_by, column] + + def update_filters(self, filter_value, output_connections: typing.List[str] = None): + for output_connection_id in output_connections: + if self.filters.get(output_connection_id, None) is None: + self.filters[output_connection_id] = filter_value + else: + self.filters[output_connection_id] = ( + self.filters[output_connection_id] & filter_value + ) + + @classmethod + def from_payload(cls, validated_payload: pd.DataFrame, input_columns: dict): + state_df = pd.DataFrame([]) + for node_id, input_name in input_columns.items(): + state_df[node_id] = validated_payload[input_name] + + state_df[constants.OUTPUT_REFERENCE_COLUMN] = np.nan + state_df[constants.OUTPUT_MESSAGE_REFERENCE_COLUMN] = np.nan + + return cls(state_df) diff --git a/retrack/engine/executor.py b/retrack/engine/executor.py new file mode 100644 index 0000000..e859b1e --- /dev/null +++ b/retrack/engine/executor.py @@ -0,0 +1,187 @@ +import pydantic +from retrack.utils.component_registry import ComponentRegistry + +import typing + + +import pandas as pd + +from retrack.engine.request_manager import RequestManager +from retrack.nodes.base import NodeKind, NodeMemoryType +from retrack.utils import constants +from retrack.engine.base import Execution, RuleMetadata + + +class RuleExecutor: + def __init__( + self, + components_registry: ComponentRegistry, + execution_order: typing.List[str], + metadata: RuleMetadata, + ): + self._components_registry = components_registry + self._execution_order = execution_order + self._metadata = metadata + + input_nodes = self._components_registry.get_by_kind(NodeKind.INPUT) + self._input_columns = { + f"{node.id}@{constants.INPUT_OUTPUT_VALUE_CONNECTOR_NAME}": node.data.name + for node in input_nodes + } + self._request_manager = RequestManager(input_nodes) + + self._set_constants() + + @property + def execution_order(self) -> typing.List[str]: + return self._execution_order + + @property + def metadata(self) -> RuleMetadata: + return self._metadata + + @property + def request_manager(self) -> RequestManager: + return self._request_manager + + @property + def request_model(self) -> pydantic.BaseModel: + return self._request_manager.model + + @property + def constants(self) -> dict: + return self._constants + + def _set_constants(self): + constant_nodes = self._components_registry.get_by_memory_type( + NodeMemoryType.CONSTANT + ) + self._constants = {} + for node in constant_nodes: + for output_connector_name, _ in node.outputs: + self._constants[f"{node.id}@{output_connector_name}"] = node.data.value + + @property + def input_columns(self) -> dict: + return self._input_columns + + def __set_output_connection_filters( + self, + node_id: str, + filter_value: typing.Any, + execution: Execution, + connector_filter=None, + ): + if filter_value is None: + return + + output_connections = self._components_registry.get_node_output_connections( + node_id, connector_filter=connector_filter + ) + execution.update_filters(filter_value, output_connections=output_connections) + + def __get_input_params( + self, + node_dict: dict, + current_node_filter: pd.Series, + execution: Execution, + ) -> dict: + input_params = {} + + for connector_name, connections in node_dict.get("inputs", {}).items(): + if connector_name.endswith(constants.NULL_SUFFIX): + continue + + for connection in connections["connections"]: + input_params[connector_name] = execution.get_state_data( + f"{connection['node']}@{connection['output']}", + constants=self.constants, + filter_by=current_node_filter, + ) + + return input_params + + def __run_node(self, node_id: str, execution: Execution): + current_node_filter = execution.filters.get(node_id, None) + # if there is a filter, we need to set the children nodes to receive filtered data + self.__set_output_connection_filters( + node_id, current_node_filter, execution=execution + ) + + node = self._components_registry.get(node_id) + + if node.memory_type == NodeMemoryType.CONSTANT: + return + + input_params = self.__get_input_params( + node.model_dump(by_alias=True), + current_node_filter, + execution=execution, + ) + output = node.run(**input_params) + + for output_name, output_value in output.items(): + if ( + output_name == constants.OUTPUT_REFERENCE_COLUMN + or output_name == constants.OUTPUT_MESSAGE_REFERENCE_COLUMN + ): # Setting output values + execution.set_state_data(output_name, output_value, current_node_filter) + elif output_name.endswith(constants.FILTER_SUFFIX): # Setting filters + self.__set_output_connection_filters( + node_id, + output_value, + execution=execution, + connector_filter=output_name, + ) + else: # Setting node outputs to be used as inputs by other nodes + execution.set_state_data( + f"{node_id}@{output_name}", + output_value, + filter_by=current_node_filter, + ) + + def validate_payload(self, payload_df: pd.DataFrame): + if not isinstance(payload_df, pd.DataFrame): + raise ValueError("payload_df must be a pandas.DataFrame") + + return self.request_manager.validate(payload_df.reset_index(drop=True)) + + def execute( + self, + payload_df: pd.DataFrame, + return_all_states: bool = False, + ) -> pd.DataFrame: + """Executes the flow with the given payload. + + Args: + payload_df (pd.DataFrame): The payload to be used as input. + return_all_states (bool, optional): If True, returns all states. Defaults to False. + + Returns: + pd.DataFrame: The output of the flow. + """ + execution = Execution.from_payload( + validated_payload=self.validate_payload(payload_df), + input_columns=self.input_columns, + ) + + for node_id in self.execution_order: + try: + self.__run_node(node_id, execution=execution) + except Exception as e: + raise Exception( + f"Error running node {node_id} in {self.metadata.name} with version {self.metadata.version}" + ) from e + + if execution.states[constants.OUTPUT_REFERENCE_COLUMN].isna().sum() == 0: + break + + if return_all_states: + return execution.states + + return execution.states[ + [ + constants.OUTPUT_REFERENCE_COLUMN, + constants.OUTPUT_MESSAGE_REFERENCE_COLUMN, + ] + ] diff --git a/retrack/engine/rule.py b/retrack/engine/rule.py index 619c33a..0ebcce7 100644 --- a/retrack/engine/rule.py +++ b/retrack/engine/rule.py @@ -7,231 +7,13 @@ import typing -import numpy as np -import pandas as pd +from retrack.engine.base import RuleMetadata +from retrack.engine.executor import RuleExecutor -from retrack.engine.request_manager import RequestManager -from retrack.nodes.base import NodeKind, NodeMemoryType -from retrack.utils import constants - -class Execution: - def __init__(self, states: pd.DataFrame, filters: dict = None): - self.states = states - self.filters = filters or {} - - def set_state_data( - self, column: str, value: typing.Any, filter_by: typing.Any = None - ): - if filter_by is None: - self.states[column] = value - else: - self.states.loc[filter_by, column] = value - - def get_state_data( - self, column: str, constants: dict, filter_by: typing.Any = None - ): - if column in constants: - return constants[column] - - if filter_by is None: - return self.states[column] - - return self.states.loc[filter_by, column] - - def update_filters(self, filter_value, output_connections: typing.List[str] = None): - for output_connection_id in output_connections: - if self.filters.get(output_connection_id, None) is None: - self.filters[output_connection_id] = filter_value - else: - self.filters[output_connection_id] = ( - self.filters[output_connection_id] & filter_value - ) - - @classmethod - def from_payload(cls, validated_payload: pd.DataFrame, input_columns: dict): - state_df = pd.DataFrame([]) - for node_id, input_name in input_columns.items(): - state_df[node_id] = validated_payload[input_name] - - state_df[constants.OUTPUT_REFERENCE_COLUMN] = np.nan - state_df[constants.OUTPUT_MESSAGE_REFERENCE_COLUMN] = np.nan - - return cls(state_df) - - -class RuleExecutor: - def __init__(self, rule: "Rule"): - self._rule = rule - self._set_constants() - self._set_input_columns() - self._request_manager = RequestManager( - self._rule.components_registry.get_by_kind(NodeKind.INPUT) - ) - - @property - def rule(self) -> "Rule": - return self._rule - - @property - def request_manager(self) -> RequestManager: - return self._request_manager - - @property - def request_model(self) -> pydantic.BaseModel: - return self._request_manager.model - - @property - def constants(self) -> dict: - return self._constants - - def _set_constants(self): - constant_nodes = self.rule.components_registry.get_by_memory_type( - NodeMemoryType.CONSTANT - ) - self._constants = {} - for node in constant_nodes: - for output_connector_name, _ in node.outputs: - self._constants[f"{node.id}@{output_connector_name}"] = node.data.value - - @property - def input_columns(self) -> dict: - return self._input_columns - - def _set_input_columns(self): - input_nodes = self._rule.components_registry.get_by_kind(NodeKind.INPUT) - self._input_columns = { - f"{node.id}@{constants.INPUT_OUTPUT_VALUE_CONNECTOR_NAME}": node.data.name - for node in input_nodes - } - - def __set_output_connection_filters( - self, - node_id: str, - filter_value: typing.Any, - execution: Execution, - connector_filter=None, - ): - if filter_value is None: - return - - output_connections = self.rule.components_registry.get_node_output_connections( - node_id, connector_filter=connector_filter - ) - execution.update_filters(filter_value, output_connections=output_connections) - - def __get_input_params( - self, - node_dict: dict, - current_node_filter: pd.Series, - execution: Execution, - ) -> dict: - input_params = {} - - for connector_name, connections in node_dict.get("inputs", {}).items(): - if connector_name.endswith(constants.NULL_SUFFIX): - continue - - for connection in connections["connections"]: - input_params[connector_name] = execution.get_state_data( - f"{connection['node']}@{connection['output']}", - constants=self.constants, - filter_by=current_node_filter, - ) - - return input_params - - def __run_node(self, node_id: str, execution: Execution): - current_node_filter = execution.filters.get(node_id, None) - # if there is a filter, we need to set the children nodes to receive filtered data - self.__set_output_connection_filters( - node_id, current_node_filter, execution=execution - ) - - node = self.rule.components_registry.get(node_id) - - if node.memory_type == NodeMemoryType.CONSTANT: - return - - input_params = self.__get_input_params( - node.model_dump(by_alias=True), - current_node_filter, - execution=execution, - ) - output = node.run(**input_params) - - for output_name, output_value in output.items(): - if ( - output_name == constants.OUTPUT_REFERENCE_COLUMN - or output_name == constants.OUTPUT_MESSAGE_REFERENCE_COLUMN - ): # Setting output values - execution.set_state_data(output_name, output_value, current_node_filter) - elif output_name.endswith(constants.FILTER_SUFFIX): # Setting filters - self.__set_output_connection_filters( - node_id, - output_value, - execution=execution, - connector_filter=output_name, - ) - else: # Setting node outputs to be used as inputs by other nodes - execution.set_state_data( - f"{node_id}@{output_name}", - output_value, - filter_by=current_node_filter, - ) - - def validate_payload(self, payload_df: pd.DataFrame): - if not isinstance(payload_df, pd.DataFrame): - raise ValueError("payload_df must be a pandas.DataFrame") - - return self.request_manager.validate(payload_df.reset_index(drop=True)) - - def execute( - self, - payload_df: pd.DataFrame, - return_all_states: bool = False, - ) -> pd.DataFrame: - """Executes the flow with the given payload. - - Args: - payload_df (pd.DataFrame): The payload to be used as input. - return_all_states (bool, optional): If True, returns all states. Defaults to False. - - Returns: - pd.DataFrame: The output of the flow. - """ - execution = Execution.from_payload( - validated_payload=self.validate_payload(payload_df), - input_columns=self.input_columns, - ) - - for node_id in self.rule.execution_order: - try: - self.__run_node(node_id, execution=execution) - except Exception as e: - raise Exception( - f"Error running node {node_id} in {self.rule.name} with version {self.rule.version}" - ) from e - - if execution.states[constants.OUTPUT_REFERENCE_COLUMN].isna().sum() == 0: - break - - if return_all_states: - return execution.states - - return execution.states[ - [ - constants.OUTPUT_REFERENCE_COLUMN, - constants.OUTPUT_MESSAGE_REFERENCE_COLUMN, - ] - ] - - -class Rule(pydantic.BaseModel): +class Rule(RuleMetadata): model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) - name: typing.Optional[str] = None - version: str components_registry: ComponentRegistry execution_order: typing.List[str] _executor: RuleExecutor = None @@ -239,9 +21,16 @@ class Rule(pydantic.BaseModel): @property def executor(self) -> RuleExecutor: if self._executor is None: - self._executor = RuleExecutor(self) + self._executor = RuleExecutor( + self.components_registry, + self.execution_order, + self.as_metadata(), + ) return self._executor + def as_metadata(self) -> RuleMetadata: + return RuleMetadata(**self.model_dump()) + @classmethod def create( cls, @@ -253,7 +42,7 @@ def create( validate_version: bool = True, name: str = None, ): - components_registry = create_component_registry( + components_registry = Rule.create_component_registry( graph_data, nodes_registry, dynamic_nodes_registry, validator_registry ) version = graph.validate_version( @@ -276,45 +65,45 @@ def create( name=name, ) + @staticmethod + def create_component_registry( + graph_data: dict, + nodes_registry: Registry, + dynamic_nodes_registry: Registry, + validator_registry: Registry, + ) -> ComponentRegistry: + components_registry = ComponentRegistry() + graph_data = graph.validate_data(graph_data) + for node_id, node_metadata in graph_data["nodes"].items(): + if node_id in components_registry: + raise ValueError(f"Duplicate node id: {node_id}") + + node_name = node_metadata.get("name", None) + graph.check_node_name(node_name, node_id) + + node_name = node_name.lower() + + node_factory = dynamic_nodes_registry.get(node_name) + + if node_factory is not None: + validation_model = node_factory( + **node_metadata, + nodes_registry=nodes_registry, + dynamic_nodes_registry=dynamic_nodes_registry, + validator_registry=validator_registry, + rule_class=Rule, + ) + else: + validation_model = nodes_registry.get(node_name) -def create_component_registry( - graph_data: dict, - nodes_registry: Registry, - dynamic_nodes_registry: Registry, - validator_registry: Registry, -) -> ComponentRegistry: - components_registry = ComponentRegistry() - graph_data = graph.validate_data(graph_data) - for node_id, node_metadata in graph_data["nodes"].items(): - if node_id in components_registry: - raise ValueError(f"Duplicate node id: {node_id}") - - node_name = node_metadata.get("name", None) - graph.check_node_name(node_name, node_id) - - node_name = node_name.lower() - - node_factory = dynamic_nodes_registry.get(node_name) - - if node_factory is not None: - validation_model = node_factory( - **node_metadata, - nodes_registry=nodes_registry, - dynamic_nodes_registry=dynamic_nodes_registry, - validator_registry=validator_registry, - rule_class=Rule, - ) - else: - validation_model = nodes_registry.get(node_name) - - if validation_model is None: - raise ValueError(f"Unknown node name: {node_name}") + if validation_model is None: + raise ValueError(f"Unknown node name: {node_name}") - component = validation_model(**node_metadata) + component = validation_model(**node_metadata) - for input_node in component.generate_input_nodes(): - components_registry.register(input_node.id, input_node) + for input_node in component.generate_input_nodes(): + components_registry.register(input_node.id, input_node) - components_registry.register(node_id, validation_model(**node_metadata)) + components_registry.register(node_id, validation_model(**node_metadata)) - return components_registry + return components_registry From dab0e7d15a3767e6a42156292a1beb0777ceabae Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 4 Dec 2023 22:07:20 -0300 Subject: [PATCH 12/17] Handling virtual connector inputs --- retrack/engine/base.py | 9 ++++++--- retrack/engine/executor.py | 7 +++++++ retrack/engine/rule.py | 6 +++--- retrack/nodes/base.py | 1 + retrack/nodes/connectors.py | 7 ++----- retrack/nodes/dynamic/flow.py | 23 ++++++++++++++--------- 6 files changed, 33 insertions(+), 20 deletions(-) diff --git a/retrack/engine/base.py b/retrack/engine/base.py index 0345c41..97f961c 100644 --- a/retrack/engine/base.py +++ b/retrack/engine/base.py @@ -11,7 +11,10 @@ class RuleMetadata(pydantic.BaseModel): class Execution: - def __init__(self, states: pd.DataFrame, filters: dict = None): + def __init__( + self, payload: pd.DataFrame, states: pd.DataFrame, filters: dict = None + ): + self.payload = payload self.states = states self.filters = filters or {} @@ -47,9 +50,9 @@ def update_filters(self, filter_value, output_connections: typing.List[str] = No def from_payload(cls, validated_payload: pd.DataFrame, input_columns: dict): state_df = pd.DataFrame([]) for node_id, input_name in input_columns.items(): - state_df[node_id] = validated_payload[input_name] + state_df[node_id] = validated_payload[input_name].copy() state_df[constants.OUTPUT_REFERENCE_COLUMN] = np.nan state_df[constants.OUTPUT_MESSAGE_REFERENCE_COLUMN] = np.nan - return cls(state_df) + return cls(payload=validated_payload, states=state_df) diff --git a/retrack/engine/executor.py b/retrack/engine/executor.py index e859b1e..061487a 100644 --- a/retrack/engine/executor.py +++ b/retrack/engine/executor.py @@ -85,6 +85,7 @@ def __get_input_params( node_dict: dict, current_node_filter: pd.Series, execution: Execution, + include_payload: bool = False, ) -> dict: input_params = {} @@ -99,6 +100,10 @@ def __get_input_params( filter_by=current_node_filter, ) + if include_payload: + for input_column in execution.payload.columns: + input_params[input_column] = execution.payload[input_column] + return input_params def __run_node(self, node_id: str, execution: Execution): @@ -117,6 +122,8 @@ def __run_node(self, node_id: str, execution: Execution): node.model_dump(by_alias=True), current_node_filter, execution=execution, + include_payload=node.kind() == NodeKind.CONNECTOR + or node.kind() == NodeKind.FLOW, ) output = node.run(**input_params) diff --git a/retrack/engine/rule.py b/retrack/engine/rule.py index 0ebcce7..fde58cf 100644 --- a/retrack/engine/rule.py +++ b/retrack/engine/rule.py @@ -101,9 +101,9 @@ def create_component_registry( component = validation_model(**node_metadata) - for input_node in component.generate_input_nodes(): - components_registry.register(input_node.id, input_node) - components_registry.register(node_id, validation_model(**node_metadata)) + for input_node in component.generate_input_nodes(): + components_registry.register(input_node.id, input_node, overwrite=True) + return components_registry diff --git a/retrack/nodes/base.py b/retrack/nodes/base.py index 8e419e5..58425dc 100644 --- a/retrack/nodes/base.py +++ b/retrack/nodes/base.py @@ -17,6 +17,7 @@ class NodeKind(str, enum.Enum): CONNECTOR = "connector" START = "start" OTHER = "other" + FLOW = "flow" ############################################################### diff --git a/retrack/nodes/connectors.py b/retrack/nodes/connectors.py index d17e013..4ff3b58 100644 --- a/retrack/nodes/connectors.py +++ b/retrack/nodes/connectors.py @@ -6,15 +6,12 @@ class VirtualConnector(Input): def kind(self) -> NodeKind: - return NodeKind.INPUT + return NodeKind.CONNECTOR def generate_input_nodes(self) -> typing.List[Input]: - return [] + return [Input(**self.model_dump(by_alias=True))] class BaseConnector(VirtualConnector): - def kind(self) -> NodeKind: - return NodeKind.CONNECTOR - def generate_input_nodes(self) -> typing.List[Input]: raise NotImplementedError() diff --git a/retrack/nodes/dynamic/flow.py b/retrack/nodes/dynamic/flow.py index e3e453b..8cb4ea6 100644 --- a/retrack/nodes/dynamic/flow.py +++ b/retrack/nodes/dynamic/flow.py @@ -5,7 +5,7 @@ import pandas as pd import pydantic -from retrack.nodes.base import InputConnectionModel, OutputConnectionModel +from retrack.nodes.base import InputConnectionModel, OutputConnectionModel, NodeKind from retrack.nodes.dynamic.base import BaseDynamicIOModel, BaseDynamicNode from retrack.utils.registry import Registry @@ -57,20 +57,25 @@ def flow_factory( class FlowV0(BaseFlowV0Model): def run(self, **kwargs) -> typing.Dict[str, typing.Any]: - inputs_in_kwargs = {} - + print(kwargs.keys()) + input_args = {} for name, value in kwargs.items(): if name.startswith("input_"): - inputs_in_kwargs[name[len("input_") :]] = value - elif name.startswith("payload_"): - inputs_in_kwargs[name[len("payload_") :]] = value + name = name[len("input_") :] + + input_args[name] = value - response = rule_instance.executor.execute(pd.DataFrame(inputs_in_kwargs)) + response = rule_instance.executor.execute(pd.DataFrame(input_args)) return {"output_value": response["output"].values} def generate_input_nodes(self): - # TODO: check inputs that do not have input nodes - return [] + input_nodes = [] + for component in rule_instance.components_registry.data.values(): + input_nodes.extend(component.generate_input_nodes()) + return input_nodes + + def kind(self) -> NodeKind: + return NodeKind.FLOW return FlowV0 From 7412957f5c2d6369731ac0d5976831b60abae32d Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 4 Dec 2023 22:08:10 -0300 Subject: [PATCH 13/17] Apply isort --- retrack/__init__.py | 2 +- retrack/engine/base.py | 6 ++++-- retrack/engine/constructor.py | 5 ++--- retrack/engine/executor.py | 8 +++----- retrack/engine/rule.py | 11 +++++------ retrack/nodes/__init__.py | 4 ++-- retrack/nodes/base.py | 3 +-- retrack/nodes/check.py | 3 +-- retrack/nodes/datetime.py | 3 +-- retrack/nodes/dynamic/flow.py | 5 ++--- retrack/nodes/math.py | 3 +-- retrack/utils/component_registry.py | 1 + retrack/utils/graph.py | 5 +++-- tests/test_engine/test_executor.py | 2 +- tests/test_nodes/test_connectors.py | 3 ++- 15 files changed, 30 insertions(+), 34 deletions(-) diff --git a/retrack/__init__.py b/retrack/__init__.py index b3cc423..1b0cb7b 100644 --- a/retrack/__init__.py +++ b/retrack/__init__.py @@ -1,5 +1,5 @@ -from retrack.engine.rule import Rule from retrack.engine.constructor import from_json +from retrack.engine.rule import Rule from retrack.nodes import registry as nodes_registry from retrack.nodes.base import BaseNode, InputConnectionModel, OutputConnectionModel diff --git a/retrack/engine/base.py b/retrack/engine/base.py index 97f961c..a02ba35 100644 --- a/retrack/engine/base.py +++ b/retrack/engine/base.py @@ -1,7 +1,9 @@ -import pydantic import typing -import pandas as pd + import numpy as np +import pandas as pd +import pydantic + from retrack.utils import constants diff --git a/retrack/engine/constructor.py b/retrack/engine/constructor.py index b9c1d6b..b6313c1 100644 --- a/retrack/engine/constructor.py +++ b/retrack/engine/constructor.py @@ -1,10 +1,9 @@ -import typing - import json +import typing +from retrack import nodes from retrack.engine.rule import Rule, RuleExecutor from retrack.utils import registry -from retrack import nodes def from_json( diff --git a/retrack/engine/executor.py b/retrack/engine/executor.py index 061487a..b847f7e 100644 --- a/retrack/engine/executor.py +++ b/retrack/engine/executor.py @@ -1,15 +1,13 @@ -import pydantic -from retrack.utils.component_registry import ComponentRegistry - import typing - import pandas as pd +import pydantic +from retrack.engine.base import Execution, RuleMetadata from retrack.engine.request_manager import RequestManager from retrack.nodes.base import NodeKind, NodeMemoryType from retrack.utils import constants -from retrack.engine.base import Execution, RuleMetadata +from retrack.utils.component_registry import ComponentRegistry class RuleExecutor: diff --git a/retrack/engine/rule.py b/retrack/engine/rule.py index fde58cf..80df839 100644 --- a/retrack/engine/rule.py +++ b/retrack/engine/rule.py @@ -1,14 +1,13 @@ -import pydantic -from retrack.utils.component_registry import ComponentRegistry -from retrack.utils.registry import Registry - -from retrack.utils import graph -from retrack import validators import typing +import pydantic +from retrack import validators from retrack.engine.base import RuleMetadata from retrack.engine.executor import RuleExecutor +from retrack.utils import graph +from retrack.utils.component_registry import ComponentRegistry +from retrack.utils.registry import Registry class Rule(RuleMetadata): diff --git a/retrack/nodes/__init__.py b/retrack/nodes/__init__.py index a521ef0..e13b715 100644 --- a/retrack/nodes/__init__.py +++ b/retrack/nodes/__init__.py @@ -1,5 +1,6 @@ from retrack.nodes.base import BaseNode from retrack.nodes.check import Check +from retrack.nodes.connectors import BaseConnector, VirtualConnector from retrack.nodes.constants import Bool, Constant, IntervalCatV0, List from retrack.nodes.contains import Contains from retrack.nodes.datetime import CurrentYear @@ -9,14 +10,13 @@ from retrack.nodes.endswithany import EndsWithAny from retrack.nodes.inputs import Input from retrack.nodes.logic import And, Not, Or +from retrack.nodes.lowercase import LowerCase from retrack.nodes.match import If from retrack.nodes.math import AbsoluteValue, Math, Round from retrack.nodes.outputs import Output from retrack.nodes.start import Start from retrack.nodes.startswith import StartsWith from retrack.nodes.startswithany import StartsWithAny -from retrack.nodes.connectors import BaseConnector, VirtualConnector -from retrack.nodes.lowercase import LowerCase from retrack.utils.registry import Registry _registry = Registry() diff --git a/retrack/nodes/base.py b/retrack/nodes/base.py index 58425dc..0ef84d2 100644 --- a/retrack/nodes/base.py +++ b/retrack/nodes/base.py @@ -1,6 +1,5 @@ -import typing - import enum +import typing import pydantic diff --git a/retrack/nodes/check.py b/retrack/nodes/check.py index 12f1b8a..5467441 100644 --- a/retrack/nodes/check.py +++ b/retrack/nodes/check.py @@ -1,6 +1,5 @@ -import typing - import enum +import typing import pandas as pd import pydantic diff --git a/retrack/nodes/datetime.py b/retrack/nodes/datetime.py index 0829e1c..e5f3d28 100644 --- a/retrack/nodes/datetime.py +++ b/retrack/nodes/datetime.py @@ -1,6 +1,5 @@ -import typing - import datetime as dt +import typing import pandas as pd import pydantic diff --git a/retrack/nodes/dynamic/flow.py b/retrack/nodes/dynamic/flow.py index 8cb4ea6..5222817 100644 --- a/retrack/nodes/dynamic/flow.py +++ b/retrack/nodes/dynamic/flow.py @@ -1,11 +1,10 @@ -import typing - import json +import typing import pandas as pd import pydantic -from retrack.nodes.base import InputConnectionModel, OutputConnectionModel, NodeKind +from retrack.nodes.base import InputConnectionModel, NodeKind, OutputConnectionModel from retrack.nodes.dynamic.base import BaseDynamicIOModel, BaseDynamicNode from retrack.utils.registry import Registry diff --git a/retrack/nodes/math.py b/retrack/nodes/math.py index 049a59d..9cd98a4 100644 --- a/retrack/nodes/math.py +++ b/retrack/nodes/math.py @@ -1,6 +1,5 @@ -import typing - import enum +import typing import pandas as pd import pydantic diff --git a/retrack/utils/component_registry.py b/retrack/utils/component_registry.py index 0657a94..4d87c41 100644 --- a/retrack/utils/component_registry.py +++ b/retrack/utils/component_registry.py @@ -1,4 +1,5 @@ import typing + from retrack.nodes.base import BaseNode from retrack.utils.registry import Registry diff --git a/retrack/utils/graph.py b/retrack/utils/graph.py index 43985e6..74b10b6 100644 --- a/retrack/utils/graph.py +++ b/retrack/utils/graph.py @@ -1,9 +1,10 @@ +import hashlib import json from unidecode import unidecode -import hashlib -from retrack.utils.registry import Registry + from retrack.utils.component_registry import ComponentRegistry +from retrack.utils.registry import Registry def validate_version( diff --git a/tests/test_engine/test_executor.py b/tests/test_engine/test_executor.py index 7dda2e9..294e2fd 100644 --- a/tests/test_engine/test_executor.py +++ b/tests/test_engine/test_executor.py @@ -3,7 +3,7 @@ import pandas as pd import pytest -from retrack import Rule, nodes, from_json +from retrack import Rule, from_json, nodes @pytest.mark.parametrize( diff --git a/tests/test_nodes/test_connectors.py b/tests/test_nodes/test_connectors.py index 7215265..9fd1d00 100644 --- a/tests/test_nodes/test_connectors.py +++ b/tests/test_nodes/test_connectors.py @@ -1,6 +1,7 @@ -from retrack.nodes.connectors import BaseConnector import pytest +from retrack.nodes.connectors import BaseConnector + @pytest.fixture def connector_dict(): From eee6651bdee25a126895dbee60cb9a83c419d8d5 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Tue, 5 Dec 2023 10:46:44 -0300 Subject: [PATCH 14/17] Testing custom connectors --- retrack/__init__.py | 2 + retrack/nodes/__init__.py | 61 +++++++------- retrack/nodes/connectors.py | 3 + retrack/nodes/dynamic/__init__.py | 19 ++--- retrack/validators/__init__.py | 21 ++--- tests/test_engine/test_custom_connectors.py | 89 +++++++++++++++++++++ 6 files changed, 136 insertions(+), 59 deletions(-) create mode 100644 tests/test_engine/test_custom_connectors.py diff --git a/retrack/__init__.py b/retrack/__init__.py index 1b0cb7b..b552788 100644 --- a/retrack/__init__.py +++ b/retrack/__init__.py @@ -2,6 +2,7 @@ from retrack.engine.rule import Rule from retrack.nodes import registry as nodes_registry from retrack.nodes.base import BaseNode, InputConnectionModel, OutputConnectionModel +from retrack.nodes.connectors import BaseConnector __all__ = [ "Rule", @@ -10,4 +11,5 @@ "InputConnectionModel", "OutputConnectionModel", "nodes_registry", + "BaseConnector", ] diff --git a/retrack/nodes/__init__.py b/retrack/nodes/__init__.py index e13b715..2dcea04 100644 --- a/retrack/nodes/__init__.py +++ b/retrack/nodes/__init__.py @@ -19,45 +19,44 @@ from retrack.nodes.startswithany import StartsWithAny from retrack.utils.registry import Registry -_registry = Registry() - def registry() -> Registry: - return _registry + _registry = Registry() + _registry.register("Input", Input) + _registry.register( + "Connector", VirtualConnector + ) # By default, Connector is an Input + _registry.register( + "ConnectorV0", VirtualConnector + ) # By default, Connector is an Input + _registry.register("Start", Start) + _registry.register("Constant", Constant) + _registry.register("List", List) + _registry.register("Bool", Bool) + _registry.register("Output", Output) + _registry.register("Check", Check) + _registry.register("If", If) + _registry.register("And", And) + _registry.register("Or", Or) + _registry.register("Not", Not) + _registry.register("Math", Math) + _registry.register("Round", Round) + _registry.register("AbsoluteValue", AbsoluteValue) + _registry.register("StartsWith", StartsWith) + _registry.register("EndsWith", EndsWith) + _registry.register("StartsWithAny", StartsWithAny) + _registry.register("EndsWithAny", EndsWithAny) + _registry.register("Contains", Contains) + _registry.register("CurrentYear", CurrentYear) + _registry.register("IntervalCatV0", IntervalCatV0) + _registry.register("LowerCase", LowerCase) -def register(name: str, node: BaseNode) -> None: - registry().register(name, node) - + return _registry -register("Input", Input) -register("Connector", VirtualConnector) # By default, Connector is an Input -register("ConnectorV0", VirtualConnector) # By default, Connector is an Input -register("Start", Start) -register("Constant", Constant) -register("List", List) -register("Bool", Bool) -register("Output", Output) -register("Check", Check) -register("If", If) -register("And", And) -register("Or", Or) -register("Not", Not) -register("Math", Math) -register("Round", Round) -register("AbsoluteValue", AbsoluteValue) -register("StartsWith", StartsWith) -register("EndsWith", EndsWith) -register("StartsWithAny", StartsWithAny) -register("EndsWithAny", EndsWithAny) -register("Contains", Contains) -register("CurrentYear", CurrentYear) -register("IntervalCatV0", IntervalCatV0) -register("LowerCase", LowerCase) __all__ = [ "registry", - "register", "BaseNode", "dynamic_nodes_registry", "BaseDynamicNode", diff --git a/retrack/nodes/connectors.py b/retrack/nodes/connectors.py index 4ff3b58..33bd768 100644 --- a/retrack/nodes/connectors.py +++ b/retrack/nodes/connectors.py @@ -15,3 +15,6 @@ def generate_input_nodes(self) -> typing.List[Input]: class BaseConnector(VirtualConnector): def generate_input_nodes(self) -> typing.List[Input]: raise NotImplementedError() + + def run(self, **kwargs): # Keep the kwargs in the signature + raise NotImplementedError() diff --git a/retrack/nodes/dynamic/__init__.py b/retrack/nodes/dynamic/__init__.py index 5ae3890..1f89223 100644 --- a/retrack/nodes/dynamic/__init__.py +++ b/retrack/nodes/dynamic/__init__.py @@ -5,23 +5,14 @@ from retrack.nodes.dynamic.flow import flow_factory from retrack.utils.registry import Registry -_registry = Registry() - def registry() -> Registry: - return _registry + _registry = Registry() + _registry.register("CSVTableV0", csv_table_factory) + _registry.register("FlowV0", flow_factory) -def register( - name: str, - factory: typing.Callable[ - [typing.List[str], typing.List[str]], typing.Type[BaseDynamicNode] - ], -) -> None: - registry().register(name, factory) - + return _registry -register("CSVTableV0", csv_table_factory) -register("FlowV0", flow_factory) -__all__ = ["registry", "register", "BaseDynamicNode"] +__all__ = ["registry", "BaseDynamicNode"] diff --git a/retrack/validators/__init__.py b/retrack/validators/__init__.py index 8f0f759..010a025 100644 --- a/retrack/validators/__init__.py +++ b/retrack/validators/__init__.py @@ -3,22 +3,15 @@ from retrack.validators.check_is_dag import CheckIsDAG from retrack.validators.node_exists import NodeExistsValidator -_registry = Registry() - def registry() -> Registry: + _registry = Registry() + _registry.register( + "single_start_node_exists", + NodeExistsValidator("start", min_quantity=1, max_quantity=1), + ) + _registry.register("check_is_dag", CheckIsDAG()) return _registry -def register(name: str, validator: BaseValidator) -> None: - registry().register(name, validator) - - -register( - "single_start_node_exists", - NodeExistsValidator("start", min_quantity=1, max_quantity=1), -) -register("check_is_dag", CheckIsDAG()) - - -__all__ = ["registry", "register", "BaseValidator"] +__all__ = ["registry", "BaseValidator"] diff --git a/tests/test_engine/test_custom_connectors.py b/tests/test_engine/test_custom_connectors.py new file mode 100644 index 0000000..16d16c0 --- /dev/null +++ b/tests/test_engine/test_custom_connectors.py @@ -0,0 +1,89 @@ +import pytest +import pandas as pd +from retrack import from_json, BaseConnector, nodes_registry +from retrack.nodes.inputs import Input +import pandera as pa + + +@pytest.fixture +def custom_connector(): + class MyConnector(BaseConnector): + def run(self, feature_a: pd.Series, feature_b: pd.Series, **kwargs): + return {"output_value": feature_a.astype(float) + feature_b.astype(float)} + + def generate_input_nodes(self): + base_dict = self.model_dump(by_alias=True) + inputs = [] + for feature_name in ["feature_a", "feature_b"]: + base_dict["id"] = self.id + "_" + feature_name + base_dict["data"]["name"] = feature_name + inputs.append(Input(**base_dict)) + + return inputs + + return MyConnector + + +@pytest.mark.parametrize( + "filename, in_values, expected_out_values", + [ + ( + "connector-rule", + [ + {"feature_a": "1", "feature_b": "4", "multiplier": "1"}, + {"feature_a": "2", "feature_b": "1", "multiplier": "1"}, + {"feature_a": "3", "feature_b": "3", "multiplier": "1"}, + {"feature_a": "4", "feature_b": "0", "multiplier": "1"}, + ], + [ + {"output": 5.0, "message": None}, + {"output": 3.0, "message": None}, + {"output": 6.0, "message": None}, + {"output": 4.0, "message": None}, + ], + ), + ( + "rule-of-rules-with-connector", + [ + {"feature_a": "1", "feature_b": "4", "var": "1"}, + {"feature_a": "2", "feature_b": "1", "var": "1"}, + {"feature_a": "3", "feature_b": "3", "var": "1"}, + {"feature_a": "4", "feature_b": "0", "var": "1"}, + ], + [ + {"output": 5.0, "message": None}, + {"output": 3.0, "message": None}, + {"output": 6.0, "message": None}, + {"output": 4.0, "message": None}, + ], + ), + ], +) +def test_connectors_with_custom_code( + filename, in_values, expected_out_values, custom_connector +): + custom_registry = nodes_registry() + custom_registry.register("Connector", custom_connector, overwrite=True) + custom_registry.register("ConnectorV0", custom_connector, overwrite=True) + + runner = from_json( + f"tests/resources/{filename}.json", nodes_registry=custom_registry + ) + out_values = runner.execute(pd.DataFrame(in_values)) + + assert isinstance(out_values, pd.DataFrame) + assert out_values.to_dict(orient="records") == expected_out_values + + +def test_missing_input_for_custom_connectors(custom_connector): + custom_registry = nodes_registry() + custom_registry.register("Connector", custom_connector, overwrite=True) + custom_registry.register("ConnectorV0", custom_connector, overwrite=True) + + with pytest.raises(pa.errors.SchemaError): + runner = from_json( + "tests/resources/connector-rule.json", + nodes_registry=custom_registry, + ) + + runner.execute(pd.DataFrame([{"multiplier": "1", "prediction": "4"}])) From 2046b57f832907b9bba8de2fc4221ea93f847bb0 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Tue, 5 Dec 2023 10:48:54 -0300 Subject: [PATCH 15/17] Update version to 2.0.0 and remove unused import --- pyproject.toml | 2 +- retrack/nodes/dynamic/__init__.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4f43d81..38d3271 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "retrack" -version = "1.1.0" +version = "2.0.0" description = "A business rules engine" authors = ["Gabriel Guarisa ", "Nathalia Trotte "] license = "MIT" diff --git a/retrack/nodes/dynamic/__init__.py b/retrack/nodes/dynamic/__init__.py index 1f89223..29c4fbf 100644 --- a/retrack/nodes/dynamic/__init__.py +++ b/retrack/nodes/dynamic/__init__.py @@ -1,5 +1,3 @@ -import typing - from retrack.nodes.dynamic.base import BaseDynamicNode from retrack.nodes.dynamic.csv_table import csv_table_factory from retrack.nodes.dynamic.flow import flow_factory From 2efe2f11506646eac842fcf4f7acfb852a79e36c Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 11 Dec 2023 10:33:30 -0300 Subject: [PATCH 16/17] Create exception types --- retrack/engine/base.py | 34 +++++++++++++++++++++ retrack/engine/executor.py | 34 ++++++++++++--------- retrack/engine/request_manager.py | 4 +++ retrack/utils/exceptions.py | 27 ++++++++++++++++ retrack/utils/graph.py | 15 +++++---- tests/test_engine/test_custom_connectors.py | 4 +-- 6 files changed, 95 insertions(+), 23 deletions(-) create mode 100644 retrack/utils/exceptions.py diff --git a/retrack/engine/base.py b/retrack/engine/base.py index a02ba35..2567939 100644 --- a/retrack/engine/base.py +++ b/retrack/engine/base.py @@ -58,3 +58,37 @@ def from_payload(cls, validated_payload: pd.DataFrame, input_columns: dict): state_df[constants.OUTPUT_MESSAGE_REFERENCE_COLUMN] = np.nan return cls(payload=validated_payload, states=state_df) + + @property + def result(self) -> pd.DataFrame: + return self.states[ + [ + constants.OUTPUT_REFERENCE_COLUMN, + constants.OUTPUT_MESSAGE_REFERENCE_COLUMN, + ] + ] + + def has_ended(self) -> bool: + return self.states[constants.OUTPUT_REFERENCE_COLUMN].isna().sum() == 0 + + def to_dict(self) -> dict: + return { + "payload": self.payload.to_dict(), + "states": self.states.to_dict(), + "filters": {k: v.to_dict() for k, v in self.filters.items()}, + "result": self.result.to_dict(), + } + + @classmethod + def from_dict(cls, data: dict): + return cls( + payload=pd.DataFrame(data["payload"]), + states=pd.DataFrame(data["states"]), + filters={k: pd.DataFrame(v) for k, v in data["filters"].items()}, + ) + + def __repr__(self) -> str: + return f"Execution({self.to_dict()})" + + def __str__(self) -> str: + return self.__repr__() diff --git a/retrack/engine/executor.py b/retrack/engine/executor.py index b847f7e..57efc9e 100644 --- a/retrack/engine/executor.py +++ b/retrack/engine/executor.py @@ -6,7 +6,7 @@ from retrack.engine.base import Execution, RuleMetadata from retrack.engine.request_manager import RequestManager from retrack.nodes.base import NodeKind, NodeMemoryType -from retrack.utils import constants +from retrack.utils import constants, exceptions from retrack.utils.component_registry import ComponentRegistry @@ -147,14 +147,23 @@ def __run_node(self, node_id: str, execution: Execution): def validate_payload(self, payload_df: pd.DataFrame): if not isinstance(payload_df, pd.DataFrame): - raise ValueError("payload_df must be a pandas.DataFrame") + raise exceptions.ValidationException( + f"payload_df must be a pandas.DataFrame instead of {type(payload_df)}" + ) - return self.request_manager.validate(payload_df.reset_index(drop=True)) + try: + validated = self.request_manager.validate(payload_df.reset_index(drop=True)) + except Exception as e: + raise exceptions.ValidationException.from_metadata( + self.metadata, payload_df + ) from e + + return validated def execute( self, payload_df: pd.DataFrame, - return_all_states: bool = False, + return_execution: bool = False, ) -> pd.DataFrame: """Executes the flow with the given payload. @@ -174,19 +183,14 @@ def execute( try: self.__run_node(node_id, execution=execution) except Exception as e: - raise Exception( - f"Error running node {node_id} in {self.metadata.name} with version {self.metadata.version}" + raise exceptions.ExecutionException.from_metadata( + self.metadata, node_id ) from e - if execution.states[constants.OUTPUT_REFERENCE_COLUMN].isna().sum() == 0: + if execution.has_ended(): break - if return_all_states: - return execution.states + if return_execution: + return execution - return execution.states[ - [ - constants.OUTPUT_REFERENCE_COLUMN, - constants.OUTPUT_MESSAGE_REFERENCE_COLUMN, - ] - ] + return execution.result diff --git a/retrack/engine/request_manager.py b/retrack/engine/request_manager.py index a3b155f..ba4528f 100644 --- a/retrack/engine/request_manager.py +++ b/retrack/engine/request_manager.py @@ -31,6 +31,10 @@ def __init__(self, inputs: typing.List[BaseNode]): def inputs(self) -> typing.List[BaseNode]: return self._inputs + @property + def input_names(self) -> typing.List[str]: + return [input.data.name for input in self.inputs] + @inputs.setter def inputs(self, inputs: typing.List[BaseNode]): if not isinstance(inputs, list): diff --git a/retrack/utils/exceptions.py b/retrack/utils/exceptions.py new file mode 100644 index 0000000..a0d7bf3 --- /dev/null +++ b/retrack/utils/exceptions.py @@ -0,0 +1,27 @@ +import pandas as pd + + +class ExecutionException(Exception): + """Exception raised when an error occurs during execution of a command.""" + + @classmethod + def from_metadata(cls, metadata, node_id: str): + return cls( + f"Error executing node {node_id} from rule {metadata.name} version {metadata.version}" + ) + + +class ValidationException(Exception): + """Exception raised when an error occurs during validation of a command.""" + + @classmethod + def from_metadata(cls, metadata, payload_df: pd.DataFrame): + return cls( + f"Error validating rule {metadata.name} version {metadata.version} with payload {payload_df}" + ) + + +class InvalidVersionException(Exception): + """Exception raised when an invalid version is found.""" + + pass diff --git a/retrack/utils/graph.py b/retrack/utils/graph.py index 74b10b6..ff4bbce 100644 --- a/retrack/utils/graph.py +++ b/retrack/utils/graph.py @@ -5,6 +5,7 @@ from retrack.utils.component_registry import ComponentRegistry from retrack.utils.registry import Registry +from retrack.utils import exceptions def validate_version( @@ -13,26 +14,28 @@ def validate_version( version = graph_data.get("version", None) graph_json_content = ( - json.dumps(graph_data["nodes"], ensure_ascii=False) - .replace(": ", ":") + json.dumps(graph_data["nodes"], ensure_ascii=False, separators=(",", ":")) .replace("\\", "") .replace('"', "") - .replace(", ", ",") ) graph_json_content = unidecode(graph_json_content, errors="strict") calculated_hash = hashlib.sha256(graph_json_content.encode()).hexdigest()[:10] if version is None: if raise_if_null_version: - raise ValueError("Missing version") + raise exceptions.InvalidVersionException( + "Missing version. " + + "Make sure to set a version in the graph data or set raise_if_null_version to False." + ) return f"{calculated_hash}.dynamic" file_version_hash = version.split(".")[0] if file_version_hash != calculated_hash and validate_version: - raise ValueError( - f"Invalid version. Graph data has changed and the hash is different: {calculated_hash} != {file_version_hash}" + raise exceptions.InvalidVersionException( + "Invalid version. " + + f"Graph data has changed and the hash is different: {calculated_hash} != {file_version_hash}" ) return version diff --git a/tests/test_engine/test_custom_connectors.py b/tests/test_engine/test_custom_connectors.py index 16d16c0..64e112c 100644 --- a/tests/test_engine/test_custom_connectors.py +++ b/tests/test_engine/test_custom_connectors.py @@ -2,7 +2,7 @@ import pandas as pd from retrack import from_json, BaseConnector, nodes_registry from retrack.nodes.inputs import Input -import pandera as pa +from retrack.utils import exceptions @pytest.fixture @@ -80,7 +80,7 @@ def test_missing_input_for_custom_connectors(custom_connector): custom_registry.register("Connector", custom_connector, overwrite=True) custom_registry.register("ConnectorV0", custom_connector, overwrite=True) - with pytest.raises(pa.errors.SchemaError): + with pytest.raises(exceptions.ValidationException): runner = from_json( "tests/resources/connector-rule.json", nodes_registry=custom_registry, From 05487246d4ad1230e5f7c64ec8e2194b33978540 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 11 Dec 2023 10:57:29 -0300 Subject: [PATCH 17/17] Update docs --- CONTRIBUTING.md | 4 +- Makefile | 3 + README.md | 29 +-- examples/age-check.json | 296 ------------------------------ examples/age_check.py | 24 --- retrack/__init__.py | 11 +- retrack/engine/constructor.py | 14 ++ retrack/engine/executor.py | 36 +++- retrack/nodes/__init__.py | 1 + retrack/nodes/dynamic/__init__.py | 3 + 10 files changed, 66 insertions(+), 355 deletions(-) delete mode 100644 examples/age-check.json delete mode 100644 examples/age_check.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9462629..0ccf2e0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -29,11 +29,9 @@ This will run the tests with [pytest](https://docs.pytest.org/en/latest/) and sh To format the code, you can use the command: ```bash -make formatting +make linting ``` -This will run the [isort](https://github.com/PyCQA/isort) and [black](https://github.com/psf/black) commands. - ### Releasing a new version To release a new version, you need to follow these steps: diff --git a/Makefile b/Makefile index a894f70..0d8f874 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,9 @@ formatting: check-formatting: poetry run ruff check . +.PHONY: linting +linting: formatting check-formatting + .PHONY: tests tests: poetry run pytest | tee pytest-coverage.txt diff --git a/README.md b/README.md index a5f820c..40e555e 100644 --- a/README.md +++ b/README.md @@ -26,30 +26,11 @@ pip install retrack ```python import retrack -runner = retrack.Runner.from_json("your-rule.json") +rule = retrack.from_json("rule.json") -response = runner.execute(input_data) +result = rule.execute(your_data_df) ``` -Or, if you want to create the parser and runner manually: - -```python -import retrack - -# Parse the rule/model -parser = retrack.Parser(rule) - -# Create a runner -runner = retrack.Runner(parser, name="your-rule") - -# Run the rule/model passing the data -runner.execute(data) -``` - -The `Parser` class parses the rule/model and creates a graph of nodes. The `Runner` class runs the rule/model using the data passed to the runner. The `data` is a dictionary or a list of dictionaries containing the data that will be used to evaluate the conditions and execute the actions. To see wich data is required for the given rule/model, check the `runner.request_model` property that is a pydantic model used to validate the data. - -Optionally you can name the rule by passing the `name` field to the `retrack.Runner` constructor. This is useful to identify the rule when exceptions are raised. - ### Creating a rule/model A rule is a set of conditions and actions that are executed when the conditions are met. The conditions are evaluated using the data passed to the runner. The actions are executed when the conditions are met. @@ -130,10 +111,10 @@ After creating the custom node, you need to register it in the nodes registry an import retrack # Register the custom node -retrack.nodes_registry.register_node("sum", SumNode) +custom_registry = retrack.nodes_registry() +custom_registry.register("sum", SumNode) -# Parse the rule/model -parser = Parser(rule, nodes_registry=retrack.nodes_registry) +rule = retrack.from_json("rule.json", nodes_registry=custom_registry) ``` ## Contributing diff --git a/examples/age-check.json b/examples/age-check.json deleted file mode 100644 index 63a47ec..0000000 --- a/examples/age-check.json +++ /dev/null @@ -1,296 +0,0 @@ -{ - "id": "demo@0.1.0", - "nodes": { - "0": { - "id": 0, - "data": {}, - "inputs": {}, - "outputs": { - "output_up_void": { - "connections": [ - { - "node": 2, - "input": "input_void", - "data": {} - } - ] - }, - "output_down_void": { - "connections": [ - { - "node": 3, - "input": "input_void", - "data": {} - } - ] - } - }, - "position": [ - -570.16015625, - -16.7578125 - ], - "name": "Start" - }, - "2": { - "id": 2, - "data": { - "name": "age", - "default": null - }, - "inputs": { - "input_void": { - "connections": [ - { - "node": 0, - "output": "output_up_void", - "data": {} - } - ] - } - }, - "outputs": { - "output_value": { - "connections": [ - { - "node": 4, - "input": "input_value_0", - "data": {} - } - ] - } - }, - "position": [ - -262.082231911288, - -229.52363816128795 - ], - "name": "Input" - }, - "3": { - "id": 3, - "data": { - "value": "18" - }, - "inputs": { - "input_void": { - "connections": [ - { - "node": 0, - "output": "output_down_void", - "data": {} - } - ] - } - }, - "outputs": { - "output_value": { - "connections": [ - { - "node": 4, - "input": "input_value_1", - "data": {} - } - ] - } - }, - "position": [ - -266.79444352384587, - 85.59488398537597 - ], - "name": "Constant" - }, - "4": { - "id": 4, - "data": { - "operator": ">=" - }, - "inputs": { - "input_value_0": { - "connections": [ - { - "node": 2, - "output": "output_value", - "data": {} - } - ] - }, - "input_value_1": { - "connections": [ - { - "node": 3, - "output": "output_value", - "data": {} - } - ] - } - }, - "outputs": { - "output_bool": { - "connections": [ - { - "node": 6, - "input": "input_bool", - "data": {} - } - ] - } - }, - "position": [ - 45.51953125, - -136.8515625 - ], - "name": "Check" - }, - "6": { - "id": 6, - "data": {}, - "inputs": { - "input_bool": { - "connections": [ - { - "node": 4, - "output": "output_bool", - "data": {} - } - ] - } - }, - "outputs": { - "output_then_filter": { - "connections": [ - { - "node": 9, - "input": "input_void", - "data": {} - } - ] - }, - "output_else_filter": { - "connections": [ - { - "node": 8, - "input": "input_void", - "data": {} - } - ] - } - }, - "position": [ - 387.98276806872417, - -127.24641593097007 - ], - "name": "If" - }, - "8": { - "id": 8, - "data": { - "value": null - }, - "inputs": { - "input_void": { - "connections": [ - { - "node": 6, - "output": "output_else_filter", - "data": {} - } - ] - } - }, - "outputs": { - "output_bool": { - "connections": [ - { - "node": 11, - "input": "input_value", - "data": {} - } - ] - } - }, - "position": [ - 696.7790861556878, - -16.077469330043932 - ], - "name": "Bool" - }, - "9": { - "id": 9, - "data": { - "value": true - }, - "inputs": { - "input_void": { - "connections": [ - { - "node": 6, - "output": "output_then_filter", - "data": {} - } - ] - } - }, - "outputs": { - "output_bool": { - "connections": [ - { - "node": 10, - "input": "input_value", - "data": {} - } - ] - } - }, - "position": [ - 693.7214037048345, - -194.67415220412568 - ], - "name": "Bool" - }, - "10": { - "id": 10, - "data": { - "message": "valid age" - }, - "inputs": { - "input_value": { - "connections": [ - { - "node": 9, - "output": "output_bool", - "data": {} - } - ] - } - }, - "outputs": {}, - "position": [ - 1015.5346416468075, - -247.2703893983769 - ], - "name": "Output" - }, - "11": { - "id": 11, - "data": { - "message": "invalid age" - }, - "inputs": { - "input_value": { - "connections": [ - { - "node": 8, - "output": "output_bool", - "data": {} - } - ] - } - }, - "outputs": {}, - "position": [ - 1008.826235840524, - -68.88453626006572 - ], - "name": "Output" - } - } -} \ No newline at end of file diff --git a/examples/age_check.py b/examples/age_check.py deleted file mode 100644 index 89bc99e..0000000 --- a/examples/age_check.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Example of using retrack to run a rule/model. -This example uses the age-check.json rule/model to check if a person is older than 18. -""" - -import json - -import retrack - -# Load the rule/model -with open("examples/age-check.json", "r") as f: - rule = json.load(f) - -# Parse the rule/model -parser = retrack.Parser(rule) - -# Create a runner -runner = retrack.Runner(parser) - -# Run the rule/model passing the data -in_values = [10, -10, 18, 19, 100] -print("Input values:", in_values) -out_values = runner([{"age": val} for val in in_values]) -print("Output values:") -print(out_values) diff --git a/retrack/__init__.py b/retrack/__init__.py index b552788..7ed33a3 100644 --- a/retrack/__init__.py +++ b/retrack/__init__.py @@ -1,15 +1,20 @@ -from retrack.engine.constructor import from_json from retrack.engine.rule import Rule -from retrack.nodes import registry as nodes_registry +from retrack.engine.constructor import from_json +from retrack.engine.executor import RuleExecutor +from retrack.engine.base import Execution +from retrack.nodes import registry as nodes_registry, dynamic_nodes_registry from retrack.nodes.base import BaseNode, InputConnectionModel, OutputConnectionModel from retrack.nodes.connectors import BaseConnector __all__ = [ "Rule", "from_json", + "RuleExecutor", + "Execution", + "nodes_registry", + "dynamic_nodes_registry", "BaseNode", "InputConnectionModel", "OutputConnectionModel", - "nodes_registry", "BaseConnector", ] diff --git a/retrack/engine/constructor.py b/retrack/engine/constructor.py index b6313c1..21377ed 100644 --- a/retrack/engine/constructor.py +++ b/retrack/engine/constructor.py @@ -13,6 +13,20 @@ def from_json( dynamic_nodes_registry: registry.Registry = nodes.dynamic_nodes_registry(), **kwargs, ) -> RuleExecutor: + """Create a Rule Executor from a json file or a dict. + + Args: + data (typing.Union[str, dict]): json file path or a dict. + name (str, optional): Rule name. Defaults to None. + nodes_registry (registry.Registry, optional): Nodes registry. Defaults to nodes.registry(). + dynamic_nodes_registry (registry.Registry, optional): Dynamic nodes registry. Defaults to nodes.dynamic_nodes_registry(). + + Raises: + ValueError: If data is not a dict or a json file path. + + Returns: + RuleExecutor: Rule executor. + """ if isinstance(data, str) and data.endswith(".json"): if name is None: name = data diff --git a/retrack/engine/executor.py b/retrack/engine/executor.py index 57efc9e..4ff790b 100644 --- a/retrack/engine/executor.py +++ b/retrack/engine/executor.py @@ -17,6 +17,17 @@ def __init__( execution_order: typing.List[str], metadata: RuleMetadata, ): + """Class that executes a rule. + + Args: + components_registry (ComponentRegistry): Components registry. + execution_order (typing.List[str]): Execution order. + metadata (RuleMetadata): Rule metadata. + + Raises: + exceptions.ExecutionException: If there is an error during execution. + exceptions.ValidationException: If there is an error during validation. + """ self._components_registry = components_registry self._execution_order = execution_order self._metadata = metadata @@ -146,6 +157,17 @@ def __run_node(self, node_id: str, execution: Execution): ) def validate_payload(self, payload_df: pd.DataFrame): + """Validates the payload. + + Args: + payload_df (pd.DataFrame): The payload to be validated. + + Raises: + exceptions.ValidationException: If there is an error during validation. + + Returns: + pd.DataFrame: The validated payload. + """ if not isinstance(payload_df, pd.DataFrame): raise exceptions.ValidationException( f"payload_df must be a pandas.DataFrame instead of {type(payload_df)}" @@ -164,15 +186,19 @@ def execute( self, payload_df: pd.DataFrame, return_execution: bool = False, - ) -> pd.DataFrame: - """Executes the flow with the given payload. + ) -> typing.Union[pd.DataFrame, Execution]: + """Executes the rule. Args: - payload_df (pd.DataFrame): The payload to be used as input. - return_all_states (bool, optional): If True, returns all states. Defaults to False. + payload_df (pd.DataFrame): The payload to be executed. + return_execution (bool, optional): If True, returns the execution object. Defaults to False. + + Raises: + exceptions.ExecutionException: If there is an error during execution. + exceptions.ValidationException: If there is an error during validation. Returns: - pd.DataFrame: The output of the flow. + typing.Union[pd.DataFrame, Execution]: The execution result. """ execution = Execution.from_payload( validated_payload=self.validate_payload(payload_df), diff --git a/retrack/nodes/__init__.py b/retrack/nodes/__init__.py index 2dcea04..f727d26 100644 --- a/retrack/nodes/__init__.py +++ b/retrack/nodes/__init__.py @@ -21,6 +21,7 @@ def registry() -> Registry: + """Create a registry with all the nodes available in the library.""" _registry = Registry() _registry.register("Input", Input) diff --git a/retrack/nodes/dynamic/__init__.py b/retrack/nodes/dynamic/__init__.py index 29c4fbf..2f22a37 100644 --- a/retrack/nodes/dynamic/__init__.py +++ b/retrack/nodes/dynamic/__init__.py @@ -5,6 +5,9 @@ def registry() -> Registry: + """Create a registry with all the dynamic nodes available in the library. + + A dynamic node is a node that is not explicitly defined in the nodes registry, but is created dynamically from a factory function.""" _registry = Registry() _registry.register("CSVTableV0", csv_table_factory)