Skip to content

Commit

Permalink
Rename global_constants to context
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielguarisa committed Sep 10, 2024
1 parent 0a8bacf commit dea1331
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
8 changes: 4 additions & 4 deletions retrack/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ def __init__(
payload: pd.DataFrame,
states: pd.DataFrame,
filters: dict = None,
global_constants: registry.Registry = None,
context: registry.Registry = None,
):
self.payload = payload
self.states = states
self.filters = filters or {}
self.global_constants = global_constants
self.context = context

def set_state_data(
self, column: str, value: typing.Any, filter_by: typing.Any = None
Expand Down Expand Up @@ -53,7 +53,7 @@ def from_payload(
cls,
validated_payload: pd.DataFrame,
input_columns: dict,
global_constants: registry.Registry = None,
context: registry.Registry = None,
):
state_df = pd.DataFrame([])
for node_id, input_name in input_columns.items():
Expand All @@ -65,7 +65,7 @@ def from_payload(
return cls(
payload=validated_payload,
states=state_df,
global_constants=global_constants,
context=context,
)

@property
Expand Down
16 changes: 11 additions & 5 deletions retrack/engine/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __set_output_connection_filters(
execution: Execution,
connector_filter=None,
):
"""If there is a filter, we need to set the children nodes to receive filtered data"""
if filter_value is None:
return

Expand Down Expand Up @@ -111,7 +112,7 @@ def __get_input_params(

def __run_node(self, node_id: str, execution: Execution):
current_node_filter = execution.filters.get(node_id, None)
# if there is a filter, we need to set the children nodes to receive filtered data

self.__set_output_connection_filters(
node_id, current_node_filter, execution=execution
)
Expand All @@ -125,9 +126,14 @@ def __run_node(self, node_id: str, execution: Execution):
node.model_dump(by_alias=True), current_node_filter, execution=execution
)

if input_params is None or (
hasattr(input_params, "empty") and input_params.empty
):
return

# TODO: Remove this condition after adding support for kwargs in the run method for all nodes
if node.kind() == NodeKind.CONNECTOR:
input_params["global_constants"] = execution.global_constants
input_params["context"] = execution.context

output = node.run(**input_params)

Expand Down Expand Up @@ -198,7 +204,7 @@ def execute(
self,
payload_df: pd.DataFrame,
debug_mode: bool = False,
global_constants: typing.Optional[registry.Registry] = None,
context: typing.Optional[registry.Registry] = None,
) -> typing.Union[
pd.DataFrame, typing.Tuple[Execution, typing.Optional[Exception]]
]:
Expand All @@ -207,7 +213,7 @@ def execute(
Args:
payload_df (pd.DataFrame): The payload to be executed.
debug_mode (bool, optional): If True, runs the rule in debug mode and returns the exception, if any. Defaults to False.
global_constants (registry.Registry, optional): Global constants to be used during execution. Defaults to None.
context (registry.Registry, optional): Global constants to be used during execution. Defaults to None.
Raises:
exceptions.ExecutionException: If there is an error during execution.
Expand All @@ -227,7 +233,7 @@ def execute(
execution = Execution.from_payload(
validated_payload=validated_payload,
input_columns=self.input_columns,
global_constants=global_constants,
context=context,
)

for node_id in self.execution_order:
Expand Down

0 comments on commit dea1331

Please sign in to comment.