Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(variable-handling): enhance variable and segment conversion #10483

Merged
merged 2 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/core/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .types import SegmentType
from .variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
Expand Down Expand Up @@ -58,4 +59,5 @@
"ArrayStringSegment",
"FileSegment",
"FileVariable",
"ArrayFileVariable",
]
13 changes: 11 additions & 2 deletions api/core/variables/variables.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from collections.abc import Sequence
from uuid import uuid4

from pydantic import Field

from core.helper import encrypter

from .segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
Expand All @@ -24,11 +28,12 @@ class Variable(Segment):
"""

id: str = Field(
default="",
description="Unique identity for variable. It's only used by environment variables now.",
default=lambda _: str(uuid4()),
description="Unique identity for variable.",
)
name: str
description: str = Field(default="", description="Description of the variable.")
selector: Sequence[str] = Field(default_factory=list)


class StringVariable(StringSegment, Variable):
Expand Down Expand Up @@ -78,3 +83,7 @@ class NoneVariable(NoneSegment, Variable):

class FileVariable(FileSegment, Variable):
pass


class ArrayFileVariable(ArrayFileSegment, Variable):
pass
9 changes: 6 additions & 3 deletions api/core/workflow/entities/variable_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,16 @@ def add(self, selector: Sequence[str], value: Any, /) -> None:
if len(selector) < 2:
raise ValueError("Invalid selector")

if isinstance(value, Variable):
variable = value
if isinstance(value, Segment):
v = value
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
else:
v = variable_factory.build_segment(value)
segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)

hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]][hash_key] = v
self.variable_dictionary[selector[0]][hash_key] = variable

def get(self, selector: Sequence[str], /) -> Segment | None:
"""
Expand Down
80 changes: 69 additions & 11 deletions api/factories/variable_factory.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,65 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any
from uuid import uuid4

from configs import dify_config
from core.file import File
from core.variables import (
from core.variables.exc import VariableError
from core.variables.segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayNumberVariable,
ArrayObjectSegment,
ArrayObjectVariable,
ArraySegment,
ArrayStringSegment,
ArrayStringVariable,
FileSegment,
FloatSegment,
FloatVariable,
IntegerSegment,
IntegerVariable,
NoneSegment,
ObjectSegment,
ObjectVariable,
SecretVariable,
Segment,
SegmentType,
StringSegment,
)
from core.variables.types import SegmentType
from core.variables.variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
)
from core.variables.exc import VariableError


class InvalidSelectorError(ValueError):
pass


class UnsupportedSegmentTypeError(Exception):
pass


# Define the constant
SEGMENT_TO_VARIABLE_MAP = {
StringSegment: StringVariable,
IntegerSegment: IntegerVariable,
FloatSegment: FloatVariable,
ObjectSegment: ObjectVariable,
FileSegment: FileVariable,
ArrayStringSegment: ArrayStringVariable,
ArrayNumberSegment: ArrayNumberVariable,
ArrayObjectSegment: ArrayObjectVariable,
ArrayFileSegment: ArrayFileVariable,
ArrayAnySegment: ArrayAnyVariable,
NoneSegment: NoneVariable,
}


def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
Expand Down Expand Up @@ -96,3 +127,30 @@ def build_segment(value: Any, /) -> Segment:
case _:
raise ValueError(f"not supported value {value}")
raise ValueError(f"not supported value {value}")


def segment_to_variable(
*,
segment: Segment,
selector: Sequence[str],
id: str | None = None,
name: str | None = None,
description: str = "",
) -> Variable:
if isinstance(segment, Variable):
return segment
name = name or selector[-1]
id = id or str(uuid4())

segment_type = type(segment)
if segment_type not in SEGMENT_TO_VARIABLE_MAP:
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")

variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return variable_class(
id=id,
name=name,
description=description,
value=segment.value,
selector=selector,
)
5 changes: 3 additions & 2 deletions api/tests/unit_tests/core/app/segments/test_segment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from core.helper import encrypter
from core.variables import SecretVariable, StringSegment
from core.variables import SecretVariable, StringVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey

Expand Down Expand Up @@ -54,4 +54,5 @@ def test_convert_variable_to_segment_group():
segments_group = variable_pool.convert_template(template)
assert segments_group.text == "fake-user-id"
assert segments_group.log == "fake-user-id"
assert segments_group.value == [StringSegment(value="fake-user-id")]
assert isinstance(segments_group.value[0], StringVariable)
assert segments_group.value[0].value == "fake-user-id"