
Fix/futures context #1512

Merged (3 commits) on Jan 7, 2025
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `BaseVectorStoreDriver.query_vector` for querying vector stores with vectors.

### Fixed

- Occasional crash during `FuturesExecutorMixin` cleanup.

### Deprecated

- `FuturesExecutorMixin.futures_executor`. Use `FuturesExecutorMixin.create_futures_executor` instead.

## [1.1.1] - 2025-01-03

### Fixed
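A minimal migration sketch for the deprecation entry above, assuming a class that mixes in `FuturesExecutorMixin`; the `MyLoader` name and the `work` argument are illustrative and not part of the PR:

```python
from typing import Callable

from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin


class MyLoader(FuturesExecutorMixin):
    def load_all(self, work: list[Callable[[], str]]) -> list[str]:
        # Before (now emits a DeprecationWarning):
        #     results = [self.futures_executor.submit(w).result() for w in work]
        # After: create a short-lived executor and let the `with` block shut it down.
        with self.create_futures_executor() as executor:
            return [executor.submit(w).result() for w in work]
```

The context manager calls `shutdown(wait=True)` on exit, which is what makes the shared attribute and its finalizer unnecessary.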
18 changes: 10 additions & 8 deletions griptape/drivers/event_listener/base_event_listener_driver.py
@@ -30,17 +30,19 @@ def batch(self) -> list[dict]:
def publish_event(self, event: BaseEvent | dict) -> None:
event_payload = event if isinstance(event, dict) else event.to_dict()

if self.batched:
self._batch.append(event_payload)
if len(self.batch) >= self.batch_size:
self.futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
self._batch = []
else:
self.futures_executor.submit(with_contextvars(self._safe_publish_event_payload), event_payload)
with self.create_futures_executor() as futures_executor:
if self.batched:
self._batch.append(event_payload)
if len(self.batch) >= self.batch_size:
futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
self._batch = []
else:
futures_executor.submit(with_contextvars(self._safe_publish_event_payload), event_payload)

def flush_events(self) -> None:
if self.batch:
self.futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
with self.create_futures_executor() as futures_executor:
futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
self._batch = []

@abstractmethod
45 changes: 23 additions & 22 deletions griptape/drivers/vector/base_vector_store_driver.py
@@ -45,30 +45,31 @@ def upsert_text_artifacts(
meta: Optional[dict] = None,
**kwargs,
) -> list[str] | dict[str, list[str]]:
if isinstance(artifacts, list):
return utils.execute_futures_list(
[
self.futures_executor.submit(
with_contextvars(self.upsert_text_artifact), a, namespace=None, meta=meta, **kwargs
)
for a in artifacts
],
)
else:
futures_dict = {}

for namespace, artifact_list in artifacts.items():
for a in artifact_list:
if not futures_dict.get(namespace):
futures_dict[namespace] = []

futures_dict[namespace].append(
self.futures_executor.submit(
with_contextvars(self.upsert_text_artifact), a, namespace=namespace, meta=meta, **kwargs
with self.create_futures_executor() as futures_executor:
if isinstance(artifacts, list):
return utils.execute_futures_list(
[
futures_executor.submit(
with_contextvars(self.upsert_text_artifact), a, namespace=None, meta=meta, **kwargs
)
for a in artifacts
],
)
else:
futures_dict = {}

for namespace, artifact_list in artifacts.items():
for a in artifact_list:
if not futures_dict.get(namespace):
futures_dict[namespace] = []

futures_dict[namespace].append(
futures_executor.submit(
with_contextvars(self.upsert_text_artifact), a, namespace=namespace, meta=meta, **kwargs
)
)
)

return utils.execute_futures_list_dict(futures_dict)
return utils.execute_futures_list_dict(futures_dict)

def upsert_text_artifact(
self,
7 changes: 4 additions & 3 deletions griptape/engines/rag/stages/response_rag_stage.py
@@ -32,9 +32,10 @@ def modules(self) -> list[BaseRagModule]:
def run(self, context: RagContext) -> RagContext:
logging.info("ResponseRagStage: running %s retrieval modules in parallel", len(self.response_modules))

results = utils.execute_futures_list(
[self.futures_executor.submit(with_contextvars(r.run), context) for r in self.response_modules]
)
with self.create_futures_executor() as futures_executor:
results = utils.execute_futures_list(
[futures_executor.submit(with_contextvars(r.run), context) for r in self.response_modules]
)

context.outputs = results

7 changes: 4 additions & 3 deletions griptape/engines/rag/stages/retrieval_rag_stage.py
@@ -36,9 +36,10 @@ def modules(self) -> list[BaseRagModule]:
def run(self, context: RagContext) -> RagContext:
logging.info("RetrievalRagStage: running %s retrieval modules in parallel", len(self.retrieval_modules))

results = utils.execute_futures_list(
[self.futures_executor.submit(with_contextvars(r.run), context) for r in self.retrieval_modules]
)
with self.create_futures_executor() as futures_executor:
results = utils.execute_futures_list(
[futures_executor.submit(with_contextvars(r.run), context) for r in self.retrieval_modules]
)

# flatten the list of lists
results = list(itertools.chain.from_iterable(results))
13 changes: 7 additions & 6 deletions griptape/loaders/base_loader.py
@@ -61,12 +61,13 @@ def load_collection(
# to avoid duplicate work.
sources_by_key = {self.to_key(source): source for source in sources}

return execute_futures_dict(
{
key: self.futures_executor.submit(with_contextvars(self.load), source)
for key, source in sources_by_key.items()
},
)
with self.create_futures_executor() as futures_executor:
return execute_futures_dict(
{
key: futures_executor.submit(with_contextvars(self.load), source)
for key, source in sources_by_key.items()
},
)

def to_key(self, source: S) -> str:
"""Converts the source to a key for the collection."""
35 changes: 23 additions & 12 deletions griptape/mixins/futures_executor_mixin.py
@@ -1,6 +1,6 @@
from __future__ import annotations

import contextlib
import warnings
from abc import ABC
from concurrent import futures
from typing import Callable
@@ -14,16 +14,27 @@ class FuturesExecutorMixin(ABC):
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
)

futures_executor: futures.Executor = field(
default=Factory(lambda self: self.create_futures_executor(), takes_self=True)
_futures_executor: futures.Executor = field(
default=Factory(
lambda self: self.create_futures_executor(),
takes_self=True,
),
alias="futures_executor",
)

def __del__(self) -> None:
executor = self.futures_executor

if executor is not None:
self.futures_executor = None # pyright: ignore[reportAttributeAccessIssue] In practice this is safe, nobody will access this attribute after this point

with contextlib.suppress(Exception):
# don't raise exceptions in __del__
executor.shutdown(wait=True)
Comment on lines -21 to -29 (the removed __del__):

Member: Is this currently safe to remove without reintroducing the issues that led to it being added?

Member Author: I believe that was resolved by #1074. This particular method was only introduced as a clean-up effort.
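For context on the reply above: the rest of the diff replaces the long-lived, finalizer-managed executor with a short-lived one that the `with` block closes deterministically, so there is nothing left for `__del__` to do. A rough sketch of the two lifecycles, with illustrative class names that are not part of the codebase:

```python
from concurrent import futures
from typing import Callable


class OldStyle:
    """Long-lived executor, cleaned up in __del__ (the pattern removed here)."""

    def __init__(self) -> None:
        self._executor = futures.ThreadPoolExecutor()

    def publish(self, fn: Callable[[], None]) -> None:
        self._executor.submit(fn)

    def __del__(self) -> None:
        # Finalizer-based shutdown: timing depends on garbage collection and
        # interpreter teardown, which is the cleanup path the CHANGELOG's
        # "occasional crash" entry refers to.
        self._executor.shutdown(wait=True)


class NewStyle:
    """Short-lived executor per call, closed by the context manager (this PR's pattern)."""

    def publish(self, fn: Callable[[], None]) -> None:
        with futures.ThreadPoolExecutor() as executor:
            executor.submit(fn)
        # Exiting the block calls shutdown(wait=True), so the work finishes
        # and the worker threads are gone before publish() returns.
```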

@property
def futures_executor(self) -> futures.Executor:
self.__raise_deprecation_warning()
return self._futures_executor

@futures_executor.setter
def futures_executor(self, value: futures.Executor) -> None:
self.__raise_deprecation_warning()
self._futures_executor = value

def __raise_deprecation_warning(self) -> None:
warnings.warn(
"`FuturesExecutorMixin.futures_executor` is deprecated and will be removed in a future release. Use `FuturesExecutorMixin.create_futures_executor` instead.",
DeprecationWarning,
stacklevel=2,
)

Comment on the deprecation message:

Member: Maybe specify "use FuturesExecutorMixin.create_futures_executor() as a context manager instead"? Better to be explicit.

Member Author: Technically you could use it without a context manager:

executor = self.create_futures_executor()
executor.shutdown()
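To make the two usages from the exchange above concrete, here is a small sketch; it assumes only that a `create_futures_executor` factory is available, and the helper names are illustrative:

```python
from concurrent import futures
from typing import Callable


def submit_with_context_manager(
    create_futures_executor: Callable[[], futures.Executor],
    fn: Callable[[], None],
) -> None:
    # The form used throughout this PR: shutdown(wait=True) runs on block exit.
    with create_futures_executor() as executor:
        executor.submit(fn)


def submit_with_manual_shutdown(
    create_futures_executor: Callable[[], futures.Executor],
    fn: Callable[[], None],
) -> None:
    # The form mentioned in the reply: the caller owns the shutdown call.
    executor = create_futures_executor()
    try:
        executor.submit(fn)
    finally:
        executor.shutdown(wait=True)
```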
27 changes: 14 additions & 13 deletions griptape/structures/workflow.py
@@ -103,23 +103,24 @@ def insert_task(
def try_run(self, *args) -> Workflow:
exit_loop = False

while not self.is_finished() and not exit_loop:
futures_list = {}
ordered_tasks = self.order_tasks()
with self.create_futures_executor() as futures_executor:
while not self.is_finished() and not exit_loop:
futures_list = {}
ordered_tasks = self.order_tasks()

for task in ordered_tasks:
if task.can_run():
future = self.futures_executor.submit(with_contextvars(task.run))
futures_list[future] = task
for task in ordered_tasks:
if task.can_run():
future = futures_executor.submit(with_contextvars(task.run))
futures_list[future] = task

# Wait for all tasks to complete
for future in futures.as_completed(futures_list):
if isinstance(future.result(), ErrorArtifact) and self.fail_fast:
exit_loop = True
# Wait for all tasks to complete
for future in futures.as_completed(futures_list):
if isinstance(future.result(), ErrorArtifact) and self.fail_fast:
exit_loop = True

break
break

return self
return self

def context(self, task: BaseTask) -> dict[str, Any]:
context = super().context(task)
7 changes: 4 additions & 3 deletions griptape/tasks/actions_subtask.py
@@ -139,9 +139,10 @@ def try_run(self) -> BaseArtifact:
return ErrorArtifact("no tool output")

def run_actions(self, actions: list[ToolAction]) -> list[tuple[str, BaseArtifact]]:
return utils.execute_futures_list(
[self.futures_executor.submit(with_contextvars(self.run_action), a) for a in actions]
)
with self.create_futures_executor() as futures_executor:
return utils.execute_futures_list(
[futures_executor.submit(with_contextvars(self.run_action), a) for a in actions]
)

def run_action(self, action: ToolAction) -> tuple[str, BaseArtifact]:
if action.tool is not None:
@@ -8,7 +8,7 @@ class TestBaseEventListenerDriver:
def test_publish_event_no_batched(self):
executor = MagicMock()
executor.__enter__.return_value = executor
driver = MockEventListenerDriver(batched=False, futures_executor=executor)
driver = MockEventListenerDriver(batched=False, create_futures_executor=lambda: executor)
mock_event_payload = MockEvent().to_dict()

driver.publish_event(mock_event_payload)
@@ -18,7 +18,7 @@ def test_publish_event_no_batched(self):
def test_publish_event_yes_batched(self):
executor = MagicMock()
executor.__enter__.return_value = executor
driver = MockEventListenerDriver(batched=True, futures_executor=executor)
driver = MockEventListenerDriver(batched=True, create_futures_executor=lambda: executor)
mock_event_payload = MockEvent().to_dict()

# Publish 9 events to fill the batch
@@ -38,7 +38,7 @@ def test_publish_event_yes_batched(self):
def test_flush_events(self):
executor = MagicMock()
executor.__enter__.return_value = executor
driver = MockEventListenerDriver(batched=True, futures_executor=executor)
driver = MockEventListenerDriver(batched=True, create_futures_executor=lambda: executor)
driver.try_publish_event_payload_batch = MagicMock(side_effect=driver.try_publish_event_payload)

driver.flush_events()
8 changes: 8 additions & 0 deletions tests/unit/mixins/test_futures_executor_mixin.py
@@ -1,5 +1,7 @@
from concurrent import futures

import pytest

from tests.mocks.mock_futures_executor import MockFuturesExecutor


@@ -8,3 +10,9 @@ def test_futures_executor(self):
executor = futures.ThreadPoolExecutor()

assert MockFuturesExecutor(create_futures_executor=lambda: executor).futures_executor == executor

def test_deprecated_futures_executor(self):
mock_executor = MockFuturesExecutor()
with pytest.warns(DeprecationWarning):
assert mock_executor.futures_executor
mock_executor.futures_executor = futures.ThreadPoolExecutor()