Skip to content

Commit

Permalink
Fix modifying config's structure
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jul 22, 2024
1 parent 0df4872 commit 891715c
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 18 deletions.
38 changes: 28 additions & 10 deletions griptape/config/base_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from griptape.config import BaseConfig
from griptape.events import EventListener
from griptape.mixins.event_publisher_mixin import EventPublisherMixin
from griptape.utils import dict_merge

if TYPE_CHECKING:
Expand Down Expand Up @@ -38,25 +39,42 @@ class BaseStructureConfig(BaseConfig, ABC):
text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True})
audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True, metadata={"serializable": True})

_structure: Optional[Structure] = field(default=None, kw_only=True, alias="structure")
_structure: Structure = field(default=None, kw_only=True, alias="structure")
_event_listener: Optional[EventListener] = field(default=None, kw_only=True, alias="event_listener")

@property
def drivers(self) -> list:
return [
self.prompt_driver,
self.image_generation_driver,
self.image_query_driver,
self.embedding_driver,
self.vector_store_driver,
self.conversation_memory_driver,
self.text_to_speech_driver,
self.audio_transcription_driver,
]

@property
def structure(self) -> Optional[Structure]:
return self._structure

@structure.setter
def structure(self, structure: Structure) -> None:
self._structure = structure
if structure != self.structure:
event_publisher_drivers = [
driver for driver in self.drivers if driver is not None and isinstance(driver, EventPublisherMixin)
]

event_listener = EventListener(self.structure.publish_event)
for driver in event_publisher_drivers:
if self._event_listener is not None:
driver.remove_event_listener(self._event_listener)

self.prompt_driver.add_event_listener(event_listener)
self.image_generation_driver.add_event_listener(event_listener)
self.image_query_driver.add_event_listener(event_listener)
self.embedding_driver.add_event_listener(event_listener)
self.vector_store_driver.add_event_listener(event_listener)
if self.conversation_memory_driver is not None:
self.conversation_memory_driver.add_event_listener(event_listener)
self._event_listener = EventListener(structure.publish_event)
for driver in event_publisher_drivers:
driver.add_event_listener(self._event_listener)

self._structure = structure

def merge_config(self, config: dict) -> BaseStructureConfig:
base_config = self.to_dict()
Expand Down
2 changes: 0 additions & 2 deletions griptape/mixins/event_publisher_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ def add_event_listener(self, event_listener: EventListener) -> EventListener:
def remove_event_listener(self, event_listener: EventListener) -> None:
if event_listener in self.event_listeners:
self.event_listeners.remove(event_listener)
else:
raise ValueError("Event Listener not found.")

def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None:
for event_listener in self.event_listeners:
Expand Down
9 changes: 7 additions & 2 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
from griptape.artifacts import BaseArtifact, BlobArtifact, TextArtifact
from griptape.common import observable
from griptape.config import BaseStructureConfig, OpenAiStructureConfig, StructureConfig
from griptape.drivers import BaseEmbeddingDriver, BasePromptDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver
from griptape.drivers.vector.local_vector_store_driver import LocalVectorStoreDriver
from griptape.drivers import (
BaseEmbeddingDriver,
BasePromptDriver,
LocalVectorStoreDriver,
OpenAiChatPromptDriver,
OpenAiEmbeddingDriver,
)
from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine
from griptape.engines.rag import RagEngine
from griptape.engines.rag.modules import (
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/config/test_structure_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from griptape.config import StructureConfig
from griptape.structures import Agent


class TestStructureConfig:
Expand Down Expand Up @@ -60,3 +61,37 @@ def test_dot_update(self, config):
config.prompt_driver.max_tokens = 10

assert config.prompt_driver.max_tokens == 10

def test_drivers(self, config):
assert config.drivers == [
config.prompt_driver,
config.image_generation_driver,
config.image_query_driver,
config.embedding_driver,
config.vector_store_driver,
config.conversation_memory_driver,
config.text_to_speech_driver,
config.audio_transcription_driver,
]

def test_structure(self, config):
structure_1 = Agent(
config=config,
)

assert config.structure == structure_1
assert config._event_listener is not None
for driver in config.drivers:
if driver is not None:
assert config._event_listener in driver.event_listeners
assert len(driver.event_listeners) == 1

structure_2 = Agent(
config=config,
)
assert config.structure == structure_2
assert config._event_listener is not None
for driver in config.drivers:
if driver is not None:
assert config._event_listener in driver.event_listeners
assert len(driver.event_listeners) == 1
5 changes: 1 addition & 4 deletions tests/unit/mixins/test_events_mixin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from unittest.mock import Mock

import pytest

from griptape.events import EventListener
from griptape.mixins import EventPublisherMixin
from tests.mocks.mock_event import MockEvent
Expand Down Expand Up @@ -45,8 +43,7 @@ def test_remove_event_listener(self):
def test_remove_unknown_event_listener(self):
mixin = EventPublisherMixin()

with pytest.raises(ValueError):
mixin.remove_event_listener(EventListener())
mixin.remove_event_listener(EventListener())

def test_publish_event(self):
# Given
Expand Down

0 comments on commit 891715c

Please sign in to comment.