Skip to content

Commit

Permalink
Address review comments
Browse files: browse the repository at this point in the history
  • Loading branch information
andrewfrench committed Feb 8, 2024
1 parent b31b6cb commit b4ce30d
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 20 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `vector_path` field on `MongoDbAtlasVectorStoreDriver`.
- `LeonardoImageGenerationDriver` supports image to image generation.
- `OpenAiVisionImageQueryDriver` to support queries on images using OpenAI's Vision model.
- `ImageQueryClient` allowing an agent to make queries on images on disk or in memory.
- Image Query engine and task.
- `ImageQueryClient` allowing an Agent to make queries on images on disk or in Task Memory.
- `ImageQueryTask` and `ImageQueryEngine`.

### Fixed
- `BedrockStableDiffusionImageGenerationModelDriver` request parameters for SDXLv1.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class OpenAiVisionImageQueryDriver(BaseImageQueryDriver):
model: str = field(default="gpt-4-vision-preview", kw_only=True, metadata={"serializable": True})
api_type: str = field(default=openai.api_type, kw_only=True, metadata={"serializable": True})
api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
base_url: str = field(default=None, kw_only=True, metadata={"serializable": True})
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True)
organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
image_quality: Literal["auto", "low", "high"] = field(default="auto", kw_only=True, metadata={"serializable": True})
Expand Down Expand Up @@ -53,4 +53,7 @@ def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:

response = self.client.chat.completions.create(**params)

if len(response.choices) != 1:
raise Exception("Image query responses with more than one choice are not supported yet.")

return TextArtifact(response.choices[0].message.content)
6 changes: 3 additions & 3 deletions griptape/events/base_task_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class BaseTaskEvent(BaseEvent, ABC):
task_parent_ids: list[str] = field(kw_only=True, metadata={"serializable": True})
task_child_ids: list[str] = field(kw_only=True, metadata={"serializable": True})

task_input: Union[
BaseArtifact, BaseArtifact, tuple[BaseArtifact, ...], tuple[BaseArtifact, Sequence[BaseArtifact]]
] = field(kw_only=True, metadata={"serializable": True})
task_input: Union[BaseArtifact, tuple[BaseArtifact, ...], tuple[BaseArtifact, Sequence[BaseArtifact]]] = field(
kw_only=True, metadata={"serializable": True}
)
task_output: Optional[BaseArtifact] = field(kw_only=True, metadata={"serializable": True})
17 changes: 13 additions & 4 deletions griptape/tasks/image_query_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@

@define
class ImageQueryTask(BaseTask):
"""A task that executes a natural language query on one or more input images. Accepts a text prompt and a list of
images as input in one of the following formats:
- tuple of (template string, list[ImageArtifact])
- tuple of (TextArtifact, list[ImageArtifact])
- Callable that returns a tuple of (TextArtifact, list[ImageArtifact])
Attributes:
image_query_engine: The engine used to execute the query.
"""

image_query_engine: ImageQueryEngine = field(kw_only=True)
_input: tuple[str, list[ImageArtifact]] | tuple[TextArtifact, list[ImageArtifact]] | Callable[
[BaseTask], tuple[TextArtifact, list[ImageArtifact]]
Expand All @@ -30,8 +40,8 @@ def input(self) -> tuple[TextArtifact, list[ImageArtifact]]:
return self._input(self)
else:
raise ValueError(
"Input must be a tuple of a text artifact and a list of image artifacts or a callable that "
"returns a tuple of a text artifact and a list of image artifacts."
"Input must be a tuple of a TextArtifact and a list of ImageArtifacts or a callable that "
"returns a tuple of a TextArtifact and a list of ImageArtifacts."
)

@input.setter
Expand All @@ -43,8 +53,7 @@ def input(
self._input = value

def run(self) -> TextArtifact:
query = self.input[0]
image_artifacts = self.input[1]
query, image_artifacts = self.input

response = self.image_query_engine.run(query.value, image_artifacts)

Expand Down
13 changes: 3 additions & 10 deletions griptape/tools/image_query_client/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,8 @@ def query_image_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact:
"query",
description="A detailed question to be answered using the contents of the provided images.",
): str,
Literal("image_artifact_references", description="Image artifact memory references."): [
{
Literal(
"image_artifact_namespace", description="The namespace of the image artifact in memory."
): str,
Literal(
"image_artifact_name", description="The name of the image artifact in memory."
): str,
}
Literal("image_artifacts", description="Image artifact memory references."): [
{"namespace": str, "name": str}
],
"memory_name": str,
}
Expand All @@ -69,7 +62,7 @@ def query_image_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact:
)
def query_images_from_memory(self, params: dict[str, Any]) -> TextArtifact | ErrorArtifact:
query = params["values"]["query"]
image_artifact_references = params["values"]["image_artifact_references"]
image_artifact_references = params["values"]["image_artifacts"]
memory = self.find_input_memory(params["values"]["memory_name"])

if memory is None:
Expand Down

0 comments on commit b4ce30d

Please sign in to comment.