From 891715c94301b9fa9c180fad8f2032a1486b6da8 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 22 Jul 2024 13:41:33 -0700 Subject: [PATCH] Fix modifying config's structure --- griptape/config/base_structure_config.py | 38 ++++++++++++++++------ griptape/mixins/event_publisher_mixin.py | 2 -- griptape/structures/structure.py | 9 +++-- tests/unit/config/test_structure_config.py | 35 ++++++++++++++++++++ tests/unit/mixins/test_events_mixin.py | 5 +-- 5 files changed, 71 insertions(+), 18 deletions(-) diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py index 9961018a79..31949cd2f3 100644 --- a/griptape/config/base_structure_config.py +++ b/griptape/config/base_structure_config.py @@ -7,6 +7,7 @@ from griptape.config import BaseConfig from griptape.events import EventListener +from griptape.mixins.event_publisher_mixin import EventPublisherMixin from griptape.utils import dict_merge if TYPE_CHECKING: @@ -38,7 +39,21 @@ class BaseStructureConfig(BaseConfig, ABC): text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True}) audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True, metadata={"serializable": True}) - _structure: Optional[Structure] = field(default=None, kw_only=True, alias="structure") + _structure: Structure = field(default=None, kw_only=True, alias="structure") + _event_listener: Optional[EventListener] = field(default=None, kw_only=True, alias="event_listener") + + @property + def drivers(self) -> list: + return [ + self.prompt_driver, + self.image_generation_driver, + self.image_query_driver, + self.embedding_driver, + self.vector_store_driver, + self.conversation_memory_driver, + self.text_to_speech_driver, + self.audio_transcription_driver, + ] @property def structure(self) -> Optional[Structure]: @@ -46,17 +61,20 @@ def structure(self) -> Optional[Structure]: @structure.setter def structure(self, structure: Structure) -> None: - self._structure = structure + if structure != self.structure: + event_publisher_drivers = [ + driver for driver in self.drivers if driver is not None and isinstance(driver, EventPublisherMixin) + ] - event_listener = EventListener(self.structure.publish_event) + for driver in event_publisher_drivers: + if self._event_listener is not None: + driver.remove_event_listener(self._event_listener) - self.prompt_driver.add_event_listener(event_listener) - self.image_generation_driver.add_event_listener(event_listener) - self.image_query_driver.add_event_listener(event_listener) - self.embedding_driver.add_event_listener(event_listener) - self.vector_store_driver.add_event_listener(event_listener) - if self.conversation_memory_driver is not None: - self.conversation_memory_driver.add_event_listener(event_listener) + self._event_listener = EventListener(structure.publish_event) + for driver in event_publisher_drivers: + driver.add_event_listener(self._event_listener) + + self._structure = structure def merge_config(self, config: dict) -> BaseStructureConfig: base_config = self.to_dict() diff --git a/griptape/mixins/event_publisher_mixin.py b/griptape/mixins/event_publisher_mixin.py index c4bd99fe2a..67a302ed61 100644 --- a/griptape/mixins/event_publisher_mixin.py +++ b/griptape/mixins/event_publisher_mixin.py @@ -28,8 +28,6 @@ def add_event_listener(self, event_listener: EventListener) -> EventListener: def remove_event_listener(self, event_listener: EventListener) -> None: if event_listener in self.event_listeners: self.event_listeners.remove(event_listener) - else: - raise ValueError("Event Listener not found.") def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: for event_listener in self.event_listeners: diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 310841fa30..765910f5c8 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -12,8 +12,13 @@ from griptape.artifacts import BaseArtifact, BlobArtifact, TextArtifact from griptape.common import observable from griptape.config import BaseStructureConfig, OpenAiStructureConfig, StructureConfig -from griptape.drivers import BaseEmbeddingDriver, BasePromptDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver -from griptape.drivers.vector.local_vector_store_driver import LocalVectorStoreDriver +from griptape.drivers import ( + BaseEmbeddingDriver, + BasePromptDriver, + LocalVectorStoreDriver, + OpenAiChatPromptDriver, + OpenAiEmbeddingDriver, +) from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import ( diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index 96a68628f2..b9e3477e45 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -1,6 +1,7 @@ import pytest from griptape.config import StructureConfig +from griptape.structures import Agent class TestStructureConfig: @@ -60,3 +61,37 @@ def test_dot_update(self, config): config.prompt_driver.max_tokens = 10 assert config.prompt_driver.max_tokens == 10 + + def test_drivers(self, config): + assert config.drivers == [ + config.prompt_driver, + config.image_generation_driver, + config.image_query_driver, + config.embedding_driver, + config.vector_store_driver, + config.conversation_memory_driver, + config.text_to_speech_driver, + config.audio_transcription_driver, + ] + + def test_structure(self, config): + structure_1 = Agent( + config=config, + ) + + assert config.structure == structure_1 + assert config._event_listener is not None + for driver in config.drivers: + if driver is not None: + assert config._event_listener in driver.event_listeners + assert len(driver.event_listeners) == 1 + + structure_2 = Agent( + config=config, + ) + assert config.structure == structure_2 + assert config._event_listener is not None + for driver in config.drivers: + if driver is not None: + assert config._event_listener in driver.event_listeners + assert len(driver.event_listeners) == 1 diff --git a/tests/unit/mixins/test_events_mixin.py b/tests/unit/mixins/test_events_mixin.py index b934553f72..99f5541bad 100644 --- a/tests/unit/mixins/test_events_mixin.py +++ b/tests/unit/mixins/test_events_mixin.py @@ -1,7 +1,5 @@ from unittest.mock import Mock -import pytest - from griptape.events import EventListener from griptape.mixins import EventPublisherMixin from tests.mocks.mock_event import MockEvent @@ -45,8 +43,7 @@ def test_remove_event_listener(self): def test_remove_unknown_event_listener(self): mixin = EventPublisherMixin() - with pytest.raises(ValueError): - mixin.remove_event_listener(EventListener()) + mixin.remove_event_listener(EventListener()) def test_publish_event(self): # Given