diff --git a/retrack/engine/base.py b/retrack/engine/base.py index 571a8e2..12649b2 100644 --- a/retrack/engine/base.py +++ b/retrack/engine/base.py @@ -13,12 +13,12 @@ def __init__( payload: pd.DataFrame, states: pd.DataFrame, filters: dict = None, - global_constants: registry.Registry = None, + context: registry.Registry = None, ): self.payload = payload self.states = states self.filters = filters or {} - self.global_constants = global_constants + self.context = context def set_state_data( self, column: str, value: typing.Any, filter_by: typing.Any = None @@ -53,7 +53,7 @@ def from_payload( cls, validated_payload: pd.DataFrame, input_columns: dict, - global_constants: registry.Registry = None, + context: registry.Registry = None, ): state_df = pd.DataFrame([]) for node_id, input_name in input_columns.items(): @@ -65,7 +65,7 @@ def from_payload( return cls( payload=validated_payload, states=state_df, - global_constants=global_constants, + context=context, ) @property diff --git a/retrack/engine/executor.py b/retrack/engine/executor.py index 0c67987..1f83538 100644 --- a/retrack/engine/executor.py +++ b/retrack/engine/executor.py @@ -83,6 +83,7 @@ def __set_output_connection_filters( execution: Execution, connector_filter=None, ): + """If there is a filter, we need to set the children nodes to receive filtered data""" if filter_value is None: return @@ -111,7 +112,7 @@ def __get_input_params( def __run_node(self, node_id: str, execution: Execution): current_node_filter = execution.filters.get(node_id, None) - # if there is a filter, we need to set the children nodes to receive filtered data + self.__set_output_connection_filters( node_id, current_node_filter, execution=execution ) @@ -125,9 +126,14 @@ def __run_node(self, node_id: str, execution: Execution): node.model_dump(by_alias=True), current_node_filter, execution=execution ) + if input_params is None or ( + hasattr(input_params, "empty") and input_params.empty + ): + return + # TODO: Remove this condition after adding support for kwargs in the run method for all nodes if node.kind() == NodeKind.CONNECTOR: - input_params["global_constants"] = execution.global_constants + input_params["context"] = execution.context output = node.run(**input_params) @@ -198,7 +204,7 @@ def execute( self, payload_df: pd.DataFrame, debug_mode: bool = False, - global_constants: typing.Optional[registry.Registry] = None, + context: typing.Optional[registry.Registry] = None, ) -> typing.Union[ pd.DataFrame, typing.Tuple[Execution, typing.Optional[Exception]] ]: @@ -207,7 +213,7 @@ def execute( Args: payload_df (pd.DataFrame): The payload to be executed. debug_mode (bool, optional): If True, runs the rule in debug mode and returns the exception, if any. Defaults to False. - global_constants (registry.Registry, optional): Global constants to be used during execution. Defaults to None. + context (registry.Registry, optional): Global constants to be used during execution. Defaults to None. Raises: exceptions.ExecutionException: If there is an error during execution. @@ -227,7 +233,7 @@ def execute( execution = Execution.from_payload( validated_payload=validated_payload, input_columns=self.input_columns, - global_constants=global_constants, + context=context, ) for node_id in self.execution_order: