Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Global Constants #33

Merged
merged 5 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "retrack"
version = "2.7.1"
version = "2.8.0"
description = "A business rules engine"
authors = ["Gabriel Guarisa <[email protected]>"]
license = "MIT"
Expand Down
22 changes: 18 additions & 4 deletions retrack/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@
import numpy as np
import pandas as pd

from retrack.utils import constants
from retrack.utils import constants, registry
from retrack.engine.schemas import ExecutionSchema


class Execution:
def __init__(
self, payload: pd.DataFrame, states: pd.DataFrame, filters: dict = None
self,
payload: pd.DataFrame,
states: pd.DataFrame,
filters: dict = None,
global_constants: registry.Registry = None,
):
self.payload = payload
self.states = states
self.filters = filters or {}
self.global_constants = global_constants

def set_state_data(
self, column: str, value: typing.Any, filter_by: typing.Any = None
Expand Down Expand Up @@ -44,15 +49,24 @@ def update_filters(self, filter_value, output_connections: typing.List[str] = No
)

@classmethod
def from_payload(cls, validated_payload: pd.DataFrame, input_columns: dict):
def from_payload(
cls,
validated_payload: pd.DataFrame,
input_columns: dict,
global_constants: registry.Registry = None,
):
state_df = pd.DataFrame([])
for node_id, input_name in input_columns.items():
state_df[node_id] = validated_payload[input_name].copy()

state_df[constants.OUTPUT_REFERENCE_COLUMN] = np.nan
state_df[constants.OUTPUT_MESSAGE_REFERENCE_COLUMN] = np.nan

return cls(payload=validated_payload, states=state_df)
return cls(
payload=validated_payload,
states=state_df,
global_constants=global_constants,
)

@property
def result(self) -> pd.DataFrame:
Expand Down
10 changes: 9 additions & 1 deletion retrack/engine/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from retrack.engine.schemas import RuleMetadata
from retrack.engine.request_manager import RequestManager
from retrack.nodes.base import NodeKind, NodeMemoryType
from retrack.utils import constants, exceptions
from retrack.utils import constants, exceptions, registry
from retrack.utils.component_registry import ComponentRegistry


Expand Down Expand Up @@ -124,6 +124,11 @@ def __run_node(self, node_id: str, execution: Execution):
input_params = self.__get_input_params(
node.model_dump(by_alias=True), current_node_filter, execution=execution
)

# TODO: Remove this condition after adding support for kwargs in the run method for all nodes
if node.kind() == NodeKind.CONNECTOR:
input_params["global_constants"] = execution.global_constants

output = node.run(**input_params)

for output_name, output_value in output.items():
Expand Down Expand Up @@ -193,6 +198,7 @@ def execute(
self,
payload_df: pd.DataFrame,
debug_mode: bool = False,
global_constants: typing.Optional[registry.Registry] = None,
) -> typing.Union[
pd.DataFrame, typing.Tuple[Execution, typing.Optional[Exception]]
]:
Expand All @@ -201,6 +207,7 @@ def execute(
Args:
payload_df (pd.DataFrame): The payload to be executed.
debug_mode (bool, optional): If True, runs the rule in debug mode and returns the exception, if any. Defaults to False.
global_constants (registry.Registry, optional): Global constants to be used during execution. Defaults to None.

Raises:
exceptions.ExecutionException: If there is an error during execution.
Expand All @@ -220,6 +227,7 @@ def execute(
execution = Execution.from_payload(
validated_payload=validated_payload,
input_columns=self.input_columns,
global_constants=global_constants,
)

for node_id in self.execution_order:
Expand Down
2 changes: 1 addition & 1 deletion retrack/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def cast_empty_string_to_none(v: str, info: pydantic.ValidationInfo) -> typing.A

class OutputConnectionItemModel(pydantic.BaseModel):
node: CastedToStringType
input_: str = pydantic.Field(alias="input")
input: str = pydantic.Field(alias="input")


class InputConnectionItemModel(pydantic.BaseModel):
Expand Down
3 changes: 3 additions & 0 deletions retrack/nodes/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@ class BaseConnector(Input):

def kind(self) -> NodeKind:
return NodeKind.CONNECTOR

def run(self, **kwargs):
return {}
3 changes: 3 additions & 0 deletions retrack/nodes/dynamic/conditional_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,7 @@ class ConditionalConnector(BaseModel):
def kind(self) -> NodeKind:
return NodeKind.CONNECTOR

def run(self, **kwargs):
return {}

return ConditionalConnector
Loading