Skip to content

Commit

Permalink
feat(variable-handling): enhance variable and segment conversion (lan…
Browse files Browse the repository at this point in the history
  • Loading branch information
laipz8200 authored and jiangzhijie committed Nov 14, 2024
1 parent f4ebb56 commit 34e4c50
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 18 deletions.
2 changes: 2 additions & 0 deletions api/core/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .types import SegmentType
from .variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
Expand Down Expand Up @@ -58,4 +59,5 @@
"ArrayStringSegment",
"FileSegment",
"FileVariable",
"ArrayFileVariable",
]
13 changes: 11 additions & 2 deletions api/core/variables/variables.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from collections.abc import Sequence
from uuid import uuid4

from pydantic import Field

from core.helper import encrypter

from .segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
Expand All @@ -24,11 +28,12 @@ class Variable(Segment):
"""

id: str = Field(
default="",
description="Unique identity for variable. It's only used by environment variables now.",
default=lambda _: str(uuid4()),
description="Unique identity for variable.",
)
name: str
description: str = Field(default="", description="Description of the variable.")
selector: Sequence[str] = Field(default_factory=list)


class StringVariable(StringSegment, Variable):
Expand Down Expand Up @@ -78,3 +83,7 @@ class NoneVariable(NoneSegment, Variable):

class FileVariable(FileSegment, Variable):
pass


class ArrayFileVariable(ArrayFileSegment, Variable):
pass
9 changes: 6 additions & 3 deletions api/core/workflow/entities/variable_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,16 @@ def add(self, selector: Sequence[str], value: Any, /) -> None:
if len(selector) < 2:
raise ValueError("Invalid selector")

if isinstance(value, Variable):
variable = value
if isinstance(value, Segment):
v = value
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
else:
v = variable_factory.build_segment(value)
segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)

hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]][hash_key] = v
self.variable_dictionary[selector[0]][hash_key] = variable

def get(self, selector: Sequence[str], /) -> Segment | None:
"""
Expand Down
80 changes: 69 additions & 11 deletions api/factories/variable_factory.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,65 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any
from uuid import uuid4

from configs import dify_config
from core.file import File
from core.variables import (
from core.variables.exc import VariableError
from core.variables.segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayNumberVariable,
ArrayObjectSegment,
ArrayObjectVariable,
ArraySegment,
ArrayStringSegment,
ArrayStringVariable,
FileSegment,
FloatSegment,
FloatVariable,
IntegerSegment,
IntegerVariable,
NoneSegment,
ObjectSegment,
ObjectVariable,
SecretVariable,
Segment,
SegmentType,
StringSegment,
)
from core.variables.types import SegmentType
from core.variables.variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
)
from core.variables.exc import VariableError


class InvalidSelectorError(ValueError):
pass


class UnsupportedSegmentTypeError(Exception):
pass


# Define the constant
SEGMENT_TO_VARIABLE_MAP = {
StringSegment: StringVariable,
IntegerSegment: IntegerVariable,
FloatSegment: FloatVariable,
ObjectSegment: ObjectVariable,
FileSegment: FileVariable,
ArrayStringSegment: ArrayStringVariable,
ArrayNumberSegment: ArrayNumberVariable,
ArrayObjectSegment: ArrayObjectVariable,
ArrayFileSegment: ArrayFileVariable,
ArrayAnySegment: ArrayAnyVariable,
NoneSegment: NoneVariable,
}


def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
Expand Down Expand Up @@ -96,3 +127,30 @@ def build_segment(value: Any, /) -> Segment:
case _:
raise ValueError(f"not supported value {value}")
raise ValueError(f"not supported value {value}")


def segment_to_variable(
*,
segment: Segment,
selector: Sequence[str],
id: str | None = None,
name: str | None = None,
description: str = "",
) -> Variable:
if isinstance(segment, Variable):
return segment
name = name or selector[-1]
id = id or str(uuid4())

segment_type = type(segment)
if segment_type not in SEGMENT_TO_VARIABLE_MAP:
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")

variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return variable_class(
id=id,
name=name,
description=description,
value=segment.value,
selector=selector,
)
5 changes: 3 additions & 2 deletions api/tests/unit_tests/core/app/segments/test_segment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from core.helper import encrypter
from core.variables import SecretVariable, StringSegment
from core.variables import SecretVariable, StringVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey

Expand Down Expand Up @@ -54,4 +54,5 @@ def test_convert_variable_to_segment_group():
segments_group = variable_pool.convert_template(template)
assert segments_group.text == "fake-user-id"
assert segments_group.log == "fake-user-id"
assert segments_group.value == [StringSegment(value="fake-user-id")]
assert isinstance(segments_group.value[0], StringVariable)
assert segments_group.value[0].value == "fake-user-id"

0 comments on commit 34e4c50

Please sign in to comment.