Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(workflow_entry): Support receive File and FileList in single step run. #10947

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/core/app/apps/advanced_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def generate(
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
Expand Down
2 changes: 1 addition & 1 deletion api/core/app/apps/agent_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def generate(
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
Expand Down
14 changes: 7 additions & 7 deletions api/core/app/apps/base_app_generator.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional

from core.app.app_config.entities import VariableEntityType
from core.file import File, FileUploadConfig
from factories import file_factory

if TYPE_CHECKING:
from core.app.app_config.entities import AppConfig, VariableEntity
from core.app.app_config.entities import VariableEntity


class BaseAppGenerator:
def _prepare_user_inputs(
self,
*,
user_inputs: Optional[Mapping[str, Any]],
app_config: "AppConfig",
variables: Sequence["VariableEntity"],
tenant_id: str,
) -> Mapping[str, Any]:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
user_inputs = {
var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
for var in variables
}
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
# Convert files in inputs to File
entity_dictionary = {item.variable: item for item in app_config.variables}
entity_dictionary = {item.variable: item for item in variables}
# Convert single file to File
files_inputs = {
k: file_factory.build_from_mapping(
mapping=v,
tenant_id=app_config.tenant_id,
tenant_id=tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
Expand All @@ -44,7 +44,7 @@ def _prepare_user_inputs(
file_list_inputs = {
k: file_factory.build_from_mappings(
mappings=v,
tenant_id=app_config.tenant_id,
tenant_id=tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
Expand Down
2 changes: 1 addition & 1 deletion api/core/app/apps/chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def generate(
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
Expand Down
4 changes: 3 additions & 1 deletion api/core/app/apps/completion/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def generate(
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id
),
query=query,
files=file_objs,
user_id=user.id,
Expand Down
4 changes: 3 additions & 1 deletion api/core/app/apps/workflow/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def generate(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
files=system_files,
user_id=user.id,
stream=stream,
Expand Down
3 changes: 0 additions & 3 deletions api/core/app/apps/workflow_app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType
from core.workflow.nodes.iteration import IterationNodeData
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
Expand Down Expand Up @@ -160,8 +159,6 @@ def _get_graph_and_variable_pool_of_single_iteration(
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=IterationNodeData(**iteration_node_config.get("data", {})),
)

return graph, variable_pool
Expand Down
2 changes: 1 addition & 1 deletion api/core/workflow/entities/node_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class NodeRunResult(BaseModel):

inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[dict[str, Any]] = None # process data
outputs: Optional[dict[str, Any]] = None # node outputs
outputs: Optional[Mapping[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
llm_usage: Optional[LLMUsage] = None # llm usage

Expand Down
97 changes: 34 additions & 63 deletions api/core/workflow/workflow_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from typing import Any, Optional, cast

from configs import dify_config
from core.app.app_config.entities import FileUploadConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File, FileTransferMethod, ImageConfig
from core.file.models import File
from core.workflow.callbacks import WorkflowCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError
Expand All @@ -18,9 +17,8 @@
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode, BaseNodeData
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.llm import LLMNodeData
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from factories import file_factory
from models.enums import UserFrom
Expand Down Expand Up @@ -115,7 +113,12 @@ def run(

@classmethod
def single_step_run(
cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict
cls,
*,
workflow: Workflow,
node_id: str,
user_id: str,
user_inputs: dict,
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
"""
Single step run workflow node
Expand All @@ -135,13 +138,9 @@ def single_step_run(
raise ValueError("nodes not found in workflow graph")

# fetch node config from node id
node_config = None
for node in nodes:
if node.get("id") == node_id:
node_config = node
break

if not node_config:
try:
node_config = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise ValueError("node id not found in workflow graph")

# Get node class
Expand All @@ -153,11 +152,7 @@ def single_step_run(
raise ValueError(f"Node class not found for node type {node_type}")

# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
)
variable_pool = VariablePool(environment_variables=workflow.environment_variables)

# init graph
graph = Graph.init(graph_config=workflow.graph_dict)
Expand All @@ -183,28 +178,24 @@ def single_step_run(

try:
# variable selector to variable mapping
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=node_config
)
except NotImplementedError:
variable_mapping = {}

cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=node_instance.node_data,
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=node_config
)
except NotImplementedError:
variable_mapping = {}

cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)
try:
# run node
generator = node_instance.run()

return node_instance, generator
except Exception as e:
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
return node_instance, generator

@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
Expand All @@ -231,12 +222,11 @@ def _handle_special_values(value: Any) -> Any:
@classmethod
def mapping_user_inputs_to_variable_pool(
cls,
*,
variable_mapping: Mapping[str, Sequence[str]],
user_inputs: dict,
variable_pool: VariablePool,
tenant_id: str,
node_type: NodeType,
node_data: BaseNodeData,
) -> None:
for node_variable, variable_selector in variable_mapping.items():
# fetch node id and variable key from node_variable
Expand All @@ -254,40 +244,21 @@ def mapping_user_inputs_to_variable_pool(
# fetch variable node id from variable selector
variable_node_id = variable_selector[0]
variable_key_list = variable_selector[1:]
variable_key_list = cast(list[str], variable_key_list)
variable_key_list = list(variable_key_list)

# get input value
input_value = user_inputs.get(node_variable)
if not input_value:
input_value = user_inputs.get(node_variable_key)

# FIXME: temp fix for image type
if node_type == NodeType.LLM:
new_value = []
if isinstance(input_value, list):
node_data = cast(LLMNodeData, node_data)

detail = node_data.vision.configs.detail if node_data.vision.configs else None

for item in input_value:
if isinstance(item, dict) and "type" in item and item["type"] == "image":
transfer_method = FileTransferMethod.value_of(item.get("transfer_method"))
mapping = {
"id": item.get("id"),
"transfer_method": transfer_method,
"upload_file_id": item.get("upload_file_id"),
"url": item.get("url"),
}
config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None)
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
config=config,
)
new_value.append(file)

if new_value:
input_value = new_value
if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value:
input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id)
if (
isinstance(input_value, list)
and all(isinstance(item, dict) for item in input_value)
and all("type" in item and "transfer_method" in item for item in input_value)
):
input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id)

# append variable and value to variable pool
variable_pool.add([variable_node_id] + variable_key_list, input_value)
10 changes: 4 additions & 6 deletions api/factories/file_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,9 @@ def build_from_mapping(
def build_from_mappings(
*,
mappings: Sequence[Mapping[str, Any]],
config: FileUploadConfig | None,
config: FileUploadConfig | None = None,
tenant_id: str,
) -> Sequence[File]:
if not config:
return []

files = [
build_from_mapping(
mapping=mapping,
Expand All @@ -96,13 +93,14 @@ def build_from_mappings(
]

if (
config
# If image config is set.
config.image_config
and config.image_config
# And the number of image files exceeds the maximum limit
and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits
):
raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}")
if config.number_limits and len(files) > config.number_limits:
if config and config.number_limits and len(files) > config.number_limits:
raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}")

return files
Expand Down
8 changes: 4 additions & 4 deletions api/services/workflow_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,10 @@ def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> A
new_app = workflow_converter.convert_to_workflow(
app_model=app_model,
account=account,
name=args.get("name"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
name=args.get("name", "Default Name"),
icon_type=args.get("icon_type", "emoji"),
icon=args.get("icon", "🤖"),
icon_background=args.get("icon_background", "#FFEAD5"),
)

return new_app
Expand Down