diff --git a/pyproject.toml b/pyproject.toml index 3a9f108..9341a9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "retrack" -version = "2.7.1" +version = "2.8.0" description = "A business rules engine" authors = ["Gabriel Guarisa "] license = "MIT" diff --git a/retrack/engine/base.py b/retrack/engine/base.py index ade2cd3..571a8e2 100644 --- a/retrack/engine/base.py +++ b/retrack/engine/base.py @@ -3,17 +3,22 @@ import numpy as np import pandas as pd -from retrack.utils import constants +from retrack.utils import constants, registry from retrack.engine.schemas import ExecutionSchema class Execution: def __init__( - self, payload: pd.DataFrame, states: pd.DataFrame, filters: dict = None + self, + payload: pd.DataFrame, + states: pd.DataFrame, + filters: dict = None, + global_constants: registry.Registry = None, ): self.payload = payload self.states = states self.filters = filters or {} + self.global_constants = global_constants def set_state_data( self, column: str, value: typing.Any, filter_by: typing.Any = None @@ -44,7 +49,12 @@ def update_filters(self, filter_value, output_connections: typing.List[str] = No ) @classmethod - def from_payload(cls, validated_payload: pd.DataFrame, input_columns: dict): + def from_payload( + cls, + validated_payload: pd.DataFrame, + input_columns: dict, + global_constants: registry.Registry = None, + ): state_df = pd.DataFrame([]) for node_id, input_name in input_columns.items(): state_df[node_id] = validated_payload[input_name].copy() @@ -52,7 +62,11 @@ def from_payload(cls, validated_payload: pd.DataFrame, input_columns: dict): state_df[constants.OUTPUT_REFERENCE_COLUMN] = np.nan state_df[constants.OUTPUT_MESSAGE_REFERENCE_COLUMN] = np.nan - return cls(payload=validated_payload, states=state_df) + return cls( + payload=validated_payload, + states=state_df, + global_constants=global_constants, + ) @property def result(self) -> pd.DataFrame: diff --git a/retrack/engine/executor.py b/retrack/engine/executor.py index 8c89d42..0c67987 100644 --- a/retrack/engine/executor.py +++ b/retrack/engine/executor.py @@ -8,7 +8,7 @@ from retrack.engine.schemas import RuleMetadata from retrack.engine.request_manager import RequestManager from retrack.nodes.base import NodeKind, NodeMemoryType -from retrack.utils import constants, exceptions +from retrack.utils import constants, exceptions, registry from retrack.utils.component_registry import ComponentRegistry @@ -124,6 +124,11 @@ def __run_node(self, node_id: str, execution: Execution): input_params = self.__get_input_params( node.model_dump(by_alias=True), current_node_filter, execution=execution ) + + # TODO: Remove this condition after adding support for kwargs in the run method for all nodes + if node.kind() == NodeKind.CONNECTOR: + input_params["global_constants"] = execution.global_constants + output = node.run(**input_params) for output_name, output_value in output.items(): @@ -193,6 +198,7 @@ def execute( self, payload_df: pd.DataFrame, debug_mode: bool = False, + global_constants: typing.Optional[registry.Registry] = None, ) -> typing.Union[ pd.DataFrame, typing.Tuple[Execution, typing.Optional[Exception]] ]: @@ -201,6 +207,7 @@ def execute( Args: payload_df (pd.DataFrame): The payload to be executed. debug_mode (bool, optional): If True, runs the rule in debug mode and returns the exception, if any. Defaults to False. + global_constants (registry.Registry, optional): Global constants to be used during execution. Defaults to None. Raises: exceptions.ExecutionException: If there is an error during execution. @@ -220,6 +227,7 @@ def execute( execution = Execution.from_payload( validated_payload=validated_payload, input_columns=self.input_columns, + global_constants=global_constants, ) for node_id in self.execution_order: diff --git a/retrack/nodes/base.py b/retrack/nodes/base.py index 534ad71..e53ed23 100644 --- a/retrack/nodes/base.py +++ b/retrack/nodes/base.py @@ -57,7 +57,7 @@ def cast_empty_string_to_none(v: str, info: pydantic.ValidationInfo) -> typing.A class OutputConnectionItemModel(pydantic.BaseModel): node: CastedToStringType - input_: str = pydantic.Field(alias="input") + input: str = pydantic.Field(alias="input") class InputConnectionItemModel(pydantic.BaseModel): diff --git a/retrack/nodes/connectors.py b/retrack/nodes/connectors.py index eda6432..44ee660 100644 --- a/retrack/nodes/connectors.py +++ b/retrack/nodes/connectors.py @@ -12,3 +12,6 @@ class BaseConnector(Input): def kind(self) -> NodeKind: return NodeKind.CONNECTOR + + def run(self, **kwargs): + return {} diff --git a/retrack/nodes/dynamic/conditional_connector.py b/retrack/nodes/dynamic/conditional_connector.py index 12a34ee..da041bd 100644 --- a/retrack/nodes/dynamic/conditional_connector.py +++ b/retrack/nodes/dynamic/conditional_connector.py @@ -44,4 +44,7 @@ class ConditionalConnector(BaseModel): def kind(self) -> NodeKind: return NodeKind.CONNECTOR + def run(self, **kwargs): + return {} + return ConditionalConnector