From eeecdb0de16f9aecdc33833d755ceb6ffbc4eae9 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 22 Oct 2024 02:08:43 +0800 Subject: [PATCH] refactor(variables): replace deprecated 'get_any' with 'get' method - Removed deprecated `get_any` method in favor of the `get` method for variable retrieval across multiple nodes. - Improved error handling for variable retrieval by adding checks for variable existence and type validation. - Introduced specific error messages for missing or invalid variable types to enhance debugging capabilities. --- api/core/workflow/entities/variable_pool.py | 21 ------ api/core/workflow/nodes/code/code_node.py | 13 ++-- .../nodes/iteration/iteration_node.py | 35 +++++++--- .../knowledge_retrieval_node.py | 11 +++- api/core/workflow/nodes/llm/node.py | 65 ++++++++++--------- .../template_transform_node.py | 9 ++- .../variable_aggregator_node.py | 14 ++-- .../workflow/nodes/test_code.py | 4 ++ 8 files changed, 95 insertions(+), 77 deletions(-) diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 2a7c7234ea6c5..f8968990d48f9 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -4,7 +4,6 @@ from typing import Any, Union from pydantic import BaseModel, Field -from typing_extensions import deprecated from core.file import File, FileAttribute, file_manager from core.variables import Segment, SegmentGroup, Variable @@ -133,26 +132,6 @@ def get(self, selector: Sequence[str], /) -> Segment | None: return value - @deprecated("This method is deprecated, use `get` instead.") - def get_any(self, selector: Sequence[str], /) -> Any | None: - """ - Retrieves the value from the variable pool based on the given selector. - - Args: - selector (Sequence[str]): The selector used to identify the variable. - - Returns: - Any: The value associated with the given selector. - - Raises: - ValueError: If the selector is invalid. - """ - if len(selector) < 2: - raise ValueError("Invalid selector") - hash_key = hash(tuple(selector[1:])) - value = self.variable_dictionary[selector[0]].get(hash_key) - return value.to_object() if value else None - def remove(self, selector: Sequence[str], /): """ Remove variables from the variable pool based on the given selector. diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index dd533ffc4c665..9d7d9027c3618 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -41,10 +41,15 @@ def _run(self) -> NodeRunResult: # Get variables variables = {} for variable_selector in self.node_data.variables: - variable = variable_selector.variable - value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) - - variables[variable] = value + variable_name = variable_selector.variable + variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + if variable is None: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=f"Variable `{variable_selector.value_selector}` not found", + ) + variables[variable_name] = variable.to_object() # Run code try: result = CodeExecutor.execute_workflow_code_template( diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 7503c13ce817a..af79da9215c5d 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -5,6 +5,7 @@ from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder +from core.variables import IntegerSegment from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.graph_engine.entities.event import ( BaseGraphEvent, @@ -147,9 +148,16 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: if NodeRunMetadataKey.ITERATION_ID not in metadata: metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any( - [self.node_id, "index"] - ) + index_variable = variable_pool.get([self.node_id, "index"]) + if not isinstance(index_variable, IntegerSegment): + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=f"Invalid index variable type: {type(index_variable)}", + ) + ) + return + metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value event.route_node_state.node_run_result.metadata = metadata yield event @@ -181,7 +189,16 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: yield event # append to iteration output variable list - current_iteration_output = variable_pool.get_any(self.node_data.output_selector) + current_iteration_output_variable = variable_pool.get(self.node_data.output_selector) + if current_iteration_output_variable is None: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=f"Iteration output variable {self.node_data.output_selector} not found", + ) + ) + return + current_iteration_output = current_iteration_output_variable.to_object() outputs.append(current_iteration_output) # remove all nodes outputs from variable pool @@ -189,11 +206,11 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: variable_pool.remove([node_id]) # move to next iteration - current_index = variable_pool.get([self.node_id, "index"]) - if current_index is None: + current_index_variable = variable_pool.get([self.node_id, "index"]) + if not isinstance(current_index_variable, IntegerSegment): raise ValueError(f"iteration {self.node_id} current index not found") - next_index = int(current_index.to_object()) + 1 + next_index = current_index_variable.value + 1 variable_pool.add([self.node_id, "index"], next_index) if next_index < len(iterator_list_value): @@ -205,9 +222,7 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: iteration_node_type=self.node_type, iteration_node_data=self.node_data, index=next_index, - pre_iteration_output=jsonable_encoder(current_iteration_output) - if current_iteration_output - else None, + pre_iteration_output=jsonable_encoder(current_iteration_output), ) yield IterationRunSucceededEvent( diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index b286f34d7f5e1..2a5795a3ed6c8 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -14,6 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.variables import StringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType @@ -39,8 +40,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): def _run(self) -> NodeRunResult: # extract variables - variable = self.graph_runtime_state.variable_pool.get_any(self.node_data.query_variable_selector) - query = variable + variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector) + if not isinstance(variable, StringSegment): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error="Query variable is not string type.", + ) + query = variable.value variables = {"query": query} if not query: return NodeRunResult( diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 4179d34f9a1f0..94aa8c5eab905 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -22,7 +22,15 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment +from core.variables import ( + ArrayAnySegment, + ArrayFileSegment, + ArraySegment, + FileSegment, + NoneSegment, + ObjectSegment, + StringSegment, +) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.enums import SystemVariableKey @@ -263,50 +271,44 @@ def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]: return variables for variable_selector in node_data.prompt_config.jinja2_variables or []: - variable = variable_selector.variable - value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) + variable_name = variable_selector.variable + variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + if variable is None: + raise ValueError(f"Variable {variable_selector.variable} not found") - def parse_dict(d: dict) -> str: + def parse_dict(input_dict: Mapping[str, Any]) -> str: """ Parse dict into string """ # check if it's a context structure - if "metadata" in d and "_source" in d["metadata"] and "content" in d: - return d["content"] + if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict: + return input_dict["content"] # else, parse the dict try: - return json.dumps(d, ensure_ascii=False) + return json.dumps(input_dict, ensure_ascii=False) except Exception: - return str(d) + return str(input_dict) - if isinstance(value, str): - value = value - elif isinstance(value, list): + if isinstance(variable, ArraySegment): result = "" - for item in value: + for item in variable.value: if isinstance(item, dict): result += parse_dict(item) - elif isinstance(item, str): - result += item - elif isinstance(item, int | float): - result += str(item) else: result += str(item) result += "\n" value = result.strip() - elif isinstance(value, dict): - value = parse_dict(value) - elif isinstance(value, int | float): - value = str(value) + elif isinstance(variable, ObjectSegment): + value = parse_dict(variable.value) else: - value = str(value) + value = variable.text - variables[variable] = value + variables[variable_name] = value return variables - def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]: + def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]: inputs = {} prompt_template = node_data.prompt_template @@ -363,14 +365,14 @@ def _fetch_context(self, node_data: LLMNodeData): if not node_data.context.variable_selector: return - context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector) - if context_value: - if isinstance(context_value, str): - yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value) - elif isinstance(context_value, list): + context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector) + if context_value_variable: + if isinstance(context_value_variable, StringSegment): + yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) + elif isinstance(context_value_variable, ArraySegment): context_str = "" original_retriever_resource = [] - for item in context_value: + for item in context_value_variable.value: if isinstance(item, str): context_str += item + "\n" else: @@ -484,11 +486,12 @@ def _fetch_memory( return None # get conversation id - conversation_id = self.graph_runtime_state.variable_pool.get_any( + conversation_id_variable = self.graph_runtime_state.variable_pool.get( ["sys", SystemVariableKey.CONVERSATION_ID.value] ) - if conversation_id is None: + if not isinstance(conversation_id_variable, StringSegment): return None + conversation_id = conversation_id_variable.value # get conversation conversation = ( diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index b85271ddc699b..0ee66784c5b86 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -33,8 +33,13 @@ def _run(self) -> NodeRunResult: variables = {} for variable_selector in self.node_data.variables: variable_name = variable_selector.variable - value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) - variables[variable_name] = value + value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + if value is None: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=f"Variable {variable_name} not found in variable pool", + ) + variables[variable_name] = value.to_object() # Run code try: result = CodeExecutor.execute_workflow_code_template( diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 05477e2a9006b..031a7b8309554 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -19,27 +19,27 @@ def _run(self) -> NodeRunResult: if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: for selector in self.node_data.variables: - variable = self.graph_runtime_state.variable_pool.get_any(selector) + variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: - outputs = {"output": variable} + outputs = {"output": variable.to_object()} - inputs = {".".join(selector[1:]): variable} + inputs = {".".join(selector[1:]): variable.to_object()} break else: for group in self.node_data.advanced_settings.groups: for selector in group.variables: - variable = self.graph_runtime_state.variable_pool.get_any(selector) + variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: - outputs[group.group_name] = {"output": variable} - inputs[".".join(selector[1:])] = variable + outputs[group.group_name] = {"output": variable.to_object()} + inputs[".".join(selector[1:])] = variable.to_object() break return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData + cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index fd0f25cf04bc0..4de985ae7c9de 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -102,6 +102,8 @@ def main(args1: int, args2: int) -> dict: } node = init_code_node(code_config) + node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1) + node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2) # execute node result = node._run() @@ -146,6 +148,8 @@ def main(args1: int, args2: int) -> dict: } node = init_code_node(code_config) + node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1) + node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2) # execute node result = node._run()