Skip to content

Commit

Permalink
Improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jul 17, 2024
1 parent c813030 commit 4a705be
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 14 deletions.
70 changes: 60 additions & 10 deletions tests/mocks/mock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
from attrs import define, field

from griptape.artifacts import TextArtifact
from griptape.common import DeltaMessage, Message, PromptStack, TextDeltaMessageContent, TextMessageContent
from griptape.artifacts.action_artifact import ActionArtifact
from griptape.common import (
ActionCallDeltaMessageContent,
ActionCallMessageContent,
DeltaMessage,
Message,
PromptStack,
TextDeltaMessageContent,
TextMessageContent,
ToolAction,
)
from griptape.drivers import BasePromptDriver
from tests.mocks.mock_tokenizer import MockTokenizer

Expand All @@ -24,16 +34,56 @@ class MockPromptDriver(BasePromptDriver):

def try_run(self, prompt_stack: PromptStack) -> Message:
output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output

return Message(
content=[TextMessageContent(TextArtifact(output))],
role=Message.ASSISTANT_ROLE,
usage=Message.Usage(input_tokens=100, output_tokens=100),
)
if self.use_native_tools and prompt_stack.tools:
if prompt_stack.messages:
return Message(
content=[TextMessageContent(TextArtifact(f"Answer: {output}"))],
role=Message.ASSISTANT_ROLE,
usage=Message.Usage(input_tokens=100, output_tokens=100),
)
else:
return Message(
content=[
ActionCallMessageContent(
ActionArtifact(
ToolAction(
tag="mock-tag",
name="MockTool",
path="test",
input={"values": {"test": "test-value"}},
)
)
)
],
role=Message.ASSISTANT_ROLE,
usage=Message.Usage(input_tokens=100, output_tokens=100),
)
else:
return Message(
content=[TextMessageContent(TextArtifact(output))],
role=Message.ASSISTANT_ROLE,
usage=Message.Usage(input_tokens=100, output_tokens=100),
)

def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output

yield DeltaMessage(content=TextDeltaMessageContent(output))

yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=100, output_tokens=100))
if self.use_native_tools and prompt_stack.tools:
if prompt_stack.messages:
yield DeltaMessage(content=TextDeltaMessageContent(f"Answer: {output}"))
yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=100, output_tokens=100))
else:
yield DeltaMessage(
content=ActionCallDeltaMessageContent(
tag="mock-tag",
name="MockName",
path="mock_path",
)
)
yield DeltaMessage(
content=ActionCallDeltaMessageContent(
partial_input='{"test-key": "test-value"}',
)
)
else:
yield DeltaMessage(content=TextDeltaMessageContent(output))
15 changes: 15 additions & 0 deletions tests/mocks/mock_text_to_speech_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from __future__ import annotations

from attrs import define, field

from griptape.artifacts import AudioArtifact
from griptape.drivers.text_to_speech.base_text_to_speech_driver import BaseTextToSpeechDriver


@define
class MockTextToSpeechDriver(BaseTextToSpeechDriver):
model: str = field(default="test-model", kw_only=True)
mock_output: str = field(default="mock output", kw_only=True)

def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
return AudioArtifact(value=self.mock_output, format="mp3")
26 changes: 23 additions & 3 deletions tests/unit/drivers/prompt/test_base_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from griptape.artifacts import ErrorArtifact, TextArtifact
from griptape.common import PromptStack
from griptape.common.prompt_stack.messages.message import Message
from griptape.common import Message, PromptStack
from griptape.events import FinishPromptEvent, StartPromptEvent
from griptape.mixins import EventsMixin
from griptape.structures import Pipeline
from griptape.tasks import PromptTask
from griptape.tasks import PromptTask, ToolkitTask
from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver
from tests.mocks.mock_prompt_driver import MockPromptDriver
from tests.mocks.mock_tool.tool import MockTool


class TestBasePromptDriver:
Expand Down Expand Up @@ -46,3 +46,23 @@ def test_run_with_stream(self):
result = MockPromptDriver(stream=True, event_listeners=pipeline.event_listeners).run(PromptStack(messages=[]))
assert isinstance(result, Message)
assert result.value == "mock output"

def test_run_with_tools(self):
driver = MockPromptDriver(max_attempts=1, use_native_tools=True)
pipeline = Pipeline(prompt_driver=driver)

pipeline.add_task(ToolkitTask(tools=[MockTool()]))

output = pipeline.run().output_task.output
assert isinstance(output, TextArtifact)
assert output.value == "mock output"

def test_run_with_tools_and_stream(self):
driver = MockPromptDriver(max_attempts=1, stream=True, use_native_tools=True)
pipeline = Pipeline(prompt_driver=driver)

pipeline.add_task(ToolkitTask(tools=[MockTool()]))

output = pipeline.run().output_task.output
assert isinstance(output, TextArtifact)
assert output.value == "mock output"
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from unittest.mock import Mock

import pytest

from griptape.events.event_listener import EventListener
from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver


class TestBaseTextToSpeechDriver:
@pytest.fixture()
def driver(self):
return MockTextToSpeechDriver()

def test_text_to_audio_publish_events(self, driver):
mock_handler = Mock()
driver.add_event_listener(EventListener(handler=mock_handler))

driver.run_text_to_audio(
["foo", "bar"],
)

call_args = mock_handler.call_args_list

args, _kwargs = call_args[0]
assert args[0].type == "StartTextToSpeechEvent"

args, _kwargs = call_args[1]
assert args[0].type == "FinishTextToSpeechEvent"
8 changes: 8 additions & 0 deletions tests/unit/mixins/test_events_mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from unittest.mock import Mock

import pytest

from griptape.events import EventListener
from griptape.mixins import EventsMixin
from tests.mocks.mock_event import MockEvent
Expand Down Expand Up @@ -40,6 +42,12 @@ def test_remove_event_listener(self):

assert len(mixin.event_listeners) == 0

def test_remove_unknown_event_listener(self):
mixin = EventsMixin()

with pytest.raises(ValueError):
mixin.remove_event_listener(EventListener())

def test_publish_event(self):
# Given
mock_handler = Mock()
Expand Down
15 changes: 14 additions & 1 deletion tests/unit/tasks/test_base_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from unittest.mock import Mock

import pytest

from griptape.artifacts import TextArtifact
from griptape.events.event_listener import EventListener
from griptape.structures import Agent, Workflow
from griptape.tasks import ActionsSubtask
from tests.mocks.mock_embedding_driver import MockEmbeddingDriver
Expand All @@ -12,7 +15,12 @@
class TestBaseTask:
@pytest.fixture()
def task(self):
agent = Agent(prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), tools=[MockTool()])
agent = Agent(
prompt_driver=MockPromptDriver(),
embedding_driver=MockEmbeddingDriver(),
tools=[MockTool()],
event_listeners=[EventListener(handler=Mock())],
)

agent.add_task(MockTask("foobar", max_meta_memory_entries=2))

Expand Down Expand Up @@ -67,3 +75,8 @@ def test_parents_output(self, task):
parent_2.output = None

assert child.parents_output_text == "foobar1\nfoobar3"

def test_execute_publish_events(self, task):
task.execute()

assert task.structure.event_listeners[0].handler.call_count == 2

0 comments on commit 4a705be

Please sign in to comment.