Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EventsMixin #984

Merged
merged 6 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions griptape/config/base_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from attrs import define, field

from griptape.config import BaseConfig
from griptape.events import EventListener
from griptape.mixins.event_publisher_mixin import EventPublisherMixin
from griptape.utils import dict_merge

if TYPE_CHECKING:
Expand All @@ -19,6 +21,7 @@
BaseTextToSpeechDriver,
BaseVectorStoreDriver,
)
from griptape.structures import Structure


@define
Expand All @@ -36,6 +39,43 @@ class BaseStructureConfig(BaseConfig, ABC):
text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True})
audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True, metadata={"serializable": True})

_structure: Structure = field(default=None, kw_only=True, alias="structure")
_event_listener: Optional[EventListener] = field(default=None, kw_only=True, alias="event_listener")

@property
def drivers(self) -> list:
return [
self.prompt_driver,
self.image_generation_driver,
self.image_query_driver,
self.embedding_driver,
self.vector_store_driver,
self.conversation_memory_driver,
self.text_to_speech_driver,
self.audio_transcription_driver,
]

@property
def structure(self) -> Optional[Structure]:
return self._structure

@structure.setter
def structure(self, structure: Structure) -> None:
if structure != self.structure:
event_publisher_drivers = [
driver for driver in self.drivers if driver is not None and isinstance(driver, EventPublisherMixin)
]

for driver in event_publisher_drivers:
if self._event_listener is not None:
driver.remove_event_listener(self._event_listener)

self._event_listener = EventListener(structure.publish_event)
for driver in event_publisher_drivers:
driver.add_event_listener(self._event_listener)

self._structure = structure

def merge_config(self, config: dict) -> BaseStructureConfig:
dylanholmes marked this conversation as resolved.
Show resolved Hide resolved
base_config = self.to_dict()
merged_config = dict_merge(base_config, config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,21 @@
from attrs import define, field

from griptape.events import FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent
from griptape.mixins import ExponentialBackoffMixin, SerializableMixin
from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin

if TYPE_CHECKING:
from griptape.artifacts import AudioArtifact, TextArtifact
from griptape.structures import Structure


@define
class BaseAudioTranscriptionDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
class BaseAudioTranscriptionDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC):
model: str = field(kw_only=True, metadata={"serializable": True})
structure: Optional[Structure] = field(default=None, kw_only=True)

def before_run(self) -> None:
if self.structure:
self.structure.publish_event(StartAudioTranscriptionEvent())
self.publish_event(StartAudioTranscriptionEvent())

def after_run(self) -> None:
if self.structure:
self.structure.publish_event(FinishAudioTranscriptionEvent())
self.publish_event(FinishAudioTranscriptionEvent())

def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
for attempt in self.retrying():
Expand Down
4 changes: 2 additions & 2 deletions griptape/drivers/embedding/base_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from attrs import define, field

from griptape.chunkers import BaseChunker, TextChunker
from griptape.mixins import ExponentialBackoffMixin, SerializableMixin
from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin

if TYPE_CHECKING:
from griptape.artifacts import TextArtifact
from griptape.tokenizers import BaseTokenizer


@define
class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
class BaseEmbeddingDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC):
"""Base Embedding Driver.

Attributes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,21 @@
from attrs import define, field

from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent
from griptape.mixins import ExponentialBackoffMixin, SerializableMixin
from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin

if TYPE_CHECKING:
from griptape.artifacts import ImageArtifact
from griptape.structures import Structure


@define
class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
class BaseImageGenerationDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC):
model: str = field(kw_only=True, metadata={"serializable": True})
structure: Optional[Structure] = field(default=None, kw_only=True)

def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None:
if self.structure:
self.structure.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts))
self.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts))

def after_run(self) -> None:
if self.structure:
self.structure.publish_event(FinishImageGenerationEvent())
self.publish_event(FinishImageGenerationEvent())

def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
for attempt in self.retrying():
Expand Down
18 changes: 7 additions & 11 deletions griptape/drivers/image_query/base_image_query_driver.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,28 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from attrs import define, field

from griptape.events import FinishImageQueryEvent, StartImageQueryEvent
from griptape.mixins import ExponentialBackoffMixin, SerializableMixin
from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin

if TYPE_CHECKING:
from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.structures import Structure


@define
class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
structure: Optional[Structure] = field(default=None, kw_only=True)
class BaseImageQueryDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC):
max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True})

def before_run(self, query: str, images: list[ImageArtifact]) -> None:
if self.structure:
self.structure.publish_event(
StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]),
)
self.publish_event(
StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]),
)

def after_run(self, result: str) -> None:
if self.structure:
self.structure.publish_event(FinishImageQueryEvent(result=result))
self.publish_event(FinishImageQueryEvent(result=result))

def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
for attempt in self.retrying():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional

from griptape.mixins import SerializableMixin
from griptape.mixins import EventPublisherMixin, SerializableMixin

if TYPE_CHECKING:
from griptape.memory.structure import BaseConversationMemory


class BaseConversationMemoryDriver(SerializableMixin, ABC):
class BaseConversationMemoryDriver(EventPublisherMixin, SerializableMixin, ABC):
@abstractmethod
def store(self, memory: BaseConversationMemory) -> None: ...

Expand Down
33 changes: 14 additions & 19 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,21 @@
observable,
)
from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent
from griptape.mixins import ExponentialBackoffMixin, SerializableMixin
from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin

if TYPE_CHECKING:
from collections.abc import Iterator

from griptape.structures import Structure
from griptape.tokenizers import BaseTokenizer


@define(kw_only=True)
class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, EventPublisherMixin, ABC):
"""Base class for the Prompt Drivers.

Attributes:
temperature: The temperature to use for the completion.
max_tokens: The maximum number of tokens to generate. If not specified, the value will be automatically generated based by the tokenizer.
structure: An optional `Structure` to publish events to.
prompt_stack_to_string: A function that converts a `PromptStack` to a string.
ignored_exception_types: A tuple of exception types to ignore.
model: The model name.
Expand All @@ -44,27 +42,24 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):

temperature: float = field(default=0.1, metadata={"serializable": True})
max_tokens: Optional[int] = field(default=None, metadata={"serializable": True})
structure: Optional[Structure] = field(default=None)
ignored_exception_types: tuple[type[Exception], ...] = field(default=Factory(lambda: (ImportError, ValueError)))
model: str = field(metadata={"serializable": True})
tokenizer: BaseTokenizer
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})

def before_run(self, prompt_stack: PromptStack) -> None:
if self.structure:
self.structure.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack))
self.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack))

def after_run(self, result: Message) -> None:
if self.structure:
self.structure.publish_event(
FinishPromptEvent(
model=self.model,
result=result.value,
input_token_count=result.usage.input_tokens,
output_token_count=result.usage.output_tokens,
),
)
self.publish_event(
FinishPromptEvent(
model=self.model,
result=result.value,
input_token_count=result.usage.input_tokens,
output_token_count=result.usage.output_tokens,
),
)

@observable(tags=["PromptDriver.run()"])
def run(self, prompt_stack: PromptStack) -> Message:
Expand Down Expand Up @@ -133,12 +128,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message:
else:
delta_contents[content.index] = [content]
if isinstance(content, TextDeltaMessageContent):
self.structure.publish_event(CompletionChunkEvent(token=content.text))
self.publish_event(CompletionChunkEvent(token=content.text))
elif isinstance(content, ActionCallDeltaMessageContent):
if content.tag is not None and content.name is not None and content.path is not None:
self.structure.publish_event(CompletionChunkEvent(token=str(content)))
self.publish_event(CompletionChunkEvent(token=str(content)))
elif content.partial_input is not None:
self.structure.publish_event(CompletionChunkEvent(token=content.partial_input))
self.publish_event(CompletionChunkEvent(token=content.partial_input))

# Build a complete content from the content deltas
result = self.__build_message(list(delta_contents.values()), usage)
Expand Down
14 changes: 5 additions & 9 deletions griptape/drivers/text_to_speech/base_text_to_speech_driver.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from attrs import define, field

from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent
from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent
from griptape.mixins import ExponentialBackoffMixin, SerializableMixin
from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin

if TYPE_CHECKING:
from griptape.artifacts.audio_artifact import AudioArtifact
from griptape.structures import Structure


@define
class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, EventPublisherMixin, ABC):
model: str = field(kw_only=True, metadata={"serializable": True})
structure: Optional[Structure] = field(default=None, kw_only=True)

def before_run(self, prompts: list[str]) -> None:
if self.structure:
self.structure.publish_event(StartTextToSpeechEvent(prompts=prompts))
self.publish_event(StartTextToSpeechEvent(prompts=prompts))

def after_run(self) -> None:
if self.structure:
self.structure.publish_event(FinishTextToSpeechEvent())
self.publish_event(FinishTextToSpeechEvent())

def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
for attempt in self.retrying():
Expand Down
4 changes: 2 additions & 2 deletions griptape/drivers/vector/base_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@

from griptape import utils
from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact
from griptape.mixins import SerializableMixin
from griptape.mixins import EventPublisherMixin, SerializableMixin

if TYPE_CHECKING:
from griptape.drivers import BaseEmbeddingDriver


@define
class BaseVectorStoreDriver(SerializableMixin, ABC):
dylanholmes marked this conversation as resolved.
Show resolved Hide resolved
class BaseVectorStoreDriver(EventPublisherMixin, SerializableMixin, ABC):
DEFAULT_QUERY_COUNT = 5

@dataclass
Expand Down
2 changes: 2 additions & 0 deletions griptape/mixins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .rule_mixin import RuleMixin
from .serializable_mixin import SerializableMixin
from .media_artifact_file_output_mixin import BlobArtifactFileOutputMixin
from .event_publisher_mixin import EventPublisherMixin

__all__ = [
"ActivityMixin",
Expand All @@ -12,4 +13,5 @@
"RuleMixin",
"BlobArtifactFileOutputMixin",
"SerializableMixin",
"EventPublisherMixin",
]
34 changes: 34 additions & 0 deletions griptape/mixins/event_publisher_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from attrs import define, field

if TYPE_CHECKING:
from griptape.events import BaseEvent, EventListener


@define
class EventPublisherMixin:
event_listeners: list[EventListener] = field(factory=list, kw_only=True)

def add_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]:
return [self.add_event_listener(event_listener) for event_listener in event_listeners]

def remove_event_listeners(self, event_listeners: list[EventListener]) -> None:
for event_listener in event_listeners:
self.remove_event_listener(event_listener)

def add_event_listener(self, event_listener: EventListener) -> EventListener:
if event_listener not in self.event_listeners:
self.event_listeners.append(event_listener)

return event_listener

def remove_event_listener(self, event_listener: EventListener) -> None:
if event_listener in self.event_listeners:
self.event_listeners.remove(event_listener)

def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None:
for event_listener in self.event_listeners:
event_listener.publish_event(event, flush=flush)
Loading
Loading