Skip to content

Commit

Permalink
Fix integrations for AnyIO 4 and prepare for 2.20.0 release (#14827)
Browse files Browse the repository at this point in the history
  • Loading branch information
abrookins authored Aug 3, 2024
1 parent 15274df commit 67774d3
Show file tree
Hide file tree
Showing 26 changed files with 253 additions and 136 deletions.
1 change: 1 addition & 0 deletions .github/workflows/integration-pacakge-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ jobs:
- name: Install dependencies
working-directory: src/integrations/${{ matrix.package }}
run: |
python -m pip install setuptools
python -m pip install -U uv
uv pip install --upgrade --system -e .[dev]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1451,7 +1451,7 @@ def _custom_network_configuration(
if not all(conf_sn in subnet_ids for conf_sn in config_subnets):
raise ValueError(
f"Subnets {config_subnets} not found within {vpc_message}."
+ "Please check that VPC is associated with supplied subnets."
+ " Please check that VPC is associated with supplied subnets."
)

return {"awsvpcConfiguration": network_configuration}
Expand Down
3 changes: 2 additions & 1 deletion src/integrations/prefect-aws/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ classifiers = [
"Topic :: Software Development :: Libraries",
]
dependencies = [
"exceptiongroup",
"boto3>=1.24.53",
"botocore>=1.27.53",
"mypy_boto3_s3>=1.24.94",
"mypy_boto3_secretsmanager>=1.26.49",
"prefect >=2.16.4, <3.0.0",
"prefect>=2.20.0, <3.0.0",
"pyparsing>=3.1.1",
"tenacity>=8.0.0",
]
Expand Down
39 changes: 27 additions & 12 deletions src/integrations/prefect-aws/tests/test_ecs.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import json
import logging
import textwrap
from contextlib import contextmanager
from copy import deepcopy
from functools import partial
from typing import Any, Awaitable, Callable, Dict, List, Optional
from typing import Any, Awaitable, Callable, Dict, Generator, List, Optional
from unittest.mock import MagicMock

import anyio
import pytest
import yaml
from botocore.exceptions import ClientError
from exceptiongroup import BaseExceptionGroup # novermin
from moto import mock_ec2, mock_ecs, mock_logs
from moto.ec2.utils import generate_instance_identity_document
from prefect_aws.workers.ecs_worker import ECSWorker
Expand Down Expand Up @@ -37,6 +39,19 @@
)


@contextmanager
def collapse_excgroups() -> Generator[None, None, None]:
try:
yield
except BaseException as exc: # novermin
while (
isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1
): # novermin
exc = exc.exceptions[0]

raise exc


def test_ecs_task_emits_deprecation_warning():
with pytest.warns(
PrefectDeprecationWarning,
Expand All @@ -52,7 +67,6 @@ def test_ecs_task_emits_deprecation_warning():

setup_logging()


BASE_TASK_DEFINITION_YAML = """
containerDefinitions:
- cpu: 1024
Expand Down Expand Up @@ -237,18 +251,19 @@ async def run_then_stop_task(
"""
session = task.aws_credentials.get_boto3_session()

with anyio.fail_after(20):
async with anyio.create_task_group() as tg:
identifier = await tg.start(task.run)
cluster, task_arn = parse_task_identifier(identifier)
with collapse_excgroups():
with anyio.fail_after(20):
async with anyio.create_task_group() as tg:
identifier = await tg.start(task.run)
cluster, task_arn = parse_task_identifier(identifier)

if after_start:
await after_start(task_arn)
if after_start:
await after_start(task_arn)

# Stop the task after it starts to prevent the test from running forever
tg.start_soon(
partial(stop_task, session.client("ecs"), task_arn, cluster=cluster)
)
# Stop the task after it starts to prevent the test from running forever
tg.start_soon(
partial(stop_task, session.client("ecs"), task_arn, cluster=cluster)
)

return task_arn

Expand Down
152 changes: 97 additions & 55 deletions src/integrations/prefect-aws/tests/workers/test_ecs_worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
import logging
from contextlib import contextmanager
from functools import partial
from typing import Any, Awaitable, Callable, Dict, List, Optional
from typing import Any, Awaitable, Callable, Dict, Generator, List, Optional
from unittest.mock import ANY, MagicMock
from unittest.mock import patch as mock_patch
from uuid import uuid4
Expand All @@ -10,6 +11,7 @@
import botocore
import pytest
import yaml
from exceptiongroup import BaseExceptionGroup, ExceptionGroup # novermin
from moto import mock_ec2, mock_ecs, mock_logs
from moto.ec2.utils import generate_instance_identity_document
from pydantic import VERSION as PYDANTIC_VERSION
Expand Down Expand Up @@ -79,6 +81,19 @@ def patch_task_watch_poll_interval(monkeypatch):
)


@contextmanager
def collapse_excgroups() -> Generator[None, None, None]:
try:
yield
except BaseException as exc: # novermin
while (
isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1
): # novermin
exc = exc.exceptions[0]

raise exc


def inject_moto_patches(moto_mock, patches: Dict[str, List[Callable]]):
def injected_call(method, patch_list, *args, **kwargs):
for patch in patch_list:
Expand Down Expand Up @@ -371,18 +386,19 @@ async def run(task_status):
result = await worker.run(flow_run, configuration, task_status=task_status)
return

with anyio.fail_after(20):
async with anyio.create_task_group() as tg:
identifier = await tg.start(run)
cluster, task_arn = parse_identifier(identifier)
with collapse_excgroups():
with anyio.fail_after(20):
async with anyio.create_task_group() as tg:
identifier = await tg.start(run)
cluster, task_arn = parse_identifier(identifier)

if after_start:
await after_start(task_arn)
if after_start:
await after_start(task_arn)

# Stop the task after it starts to prevent the test from running forever
tg.start_soon(
partial(stop_task, session.client("ecs"), task_arn, cluster=cluster)
)
# Stop the task after it starts to prevent the test from running forever
tg.start_soon(
partial(stop_task, session.client("ecs"), task_arn, cluster=cluster)
)

return result

Expand Down Expand Up @@ -1132,21 +1148,21 @@ async def test_network_config_from_custom_settings_invalid_subnet(

session = aws_credentials.get_boto3_session()

with pytest.raises(
ValueError,
match=(
r"Subnets \['sn-8asdas'\] not found within VPC with ID "
+ vpc.id
+ r"\.Please check that VPC is associated with supplied subnets\."
),
):
with pytest.raises(ExceptionGroup) as exc: # novermin
async with ECSWorker(work_pool_name="test") as worker:
original_run_task = worker._create_task_run
mock_run_task = MagicMock(side_effect=original_run_task)
worker._create_task_run = mock_run_task

await run_then_stop_task(worker, configuration, flow_run)

assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], ValueError)
assert (
f"Subnets ['sn-8asdas'] not found within VPC with ID {vpc.id}."
+ " Please check that VPC is associated with supplied subnets."
) in str(exc.value.exceptions[0])


@pytest.mark.usefixtures("ecs_mocks")
async def test_network_config_from_custom_settings_invalid_subnet_multiple_vpc_subnets(
Expand Down Expand Up @@ -1174,21 +1190,22 @@ async def test_network_config_from_custom_settings_invalid_subnet_multiple_vpc_s

session = aws_credentials.get_boto3_session()

with pytest.raises(
ValueError,
match=(
rf"Subnets \['{invalid_subnet_id}', '{subnet.id}'\] not found within VPC"
f" with ID {vpc.id}.Please check that VPC is associated with supplied"
" subnets."
),
):
with pytest.raises(ExceptionGroup) as exc:
async with ECSWorker(work_pool_name="test") as worker:
original_run_task = worker._create_task_run
mock_run_task = MagicMock(side_effect=original_run_task)
worker._create_task_run = mock_run_task

await run_then_stop_task(worker, configuration, flow_run)

assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], ValueError)
assert (
f"Subnets ['{invalid_subnet_id}', '{subnet.id}'] not found "
f"within VPC with ID {vpc.id}. Please check that VPC is "
f"associated with supplied subnets."
) in str(exc.value.exceptions[0])


@pytest.mark.usefixtures("ecs_mocks")
async def test_network_config_configure_network_requires_vpc_id(
Expand Down Expand Up @@ -1289,10 +1306,18 @@ async def test_network_config_missing_default_vpc(

configuration = await construct_configuration(aws_credentials=aws_credentials)

with pytest.raises(ValueError, match="Failed to find the default VPC"):
with pytest.raises(ExceptionGroup) as exc: # novermin
async with ECSWorker(work_pool_name="test") as worker:
original_run_task = worker._create_task_run
mock_run_task = MagicMock(side_effect=original_run_task)
worker._create_task_run = mock_run_task

await run_then_stop_task(worker, configuration, flow_run)

assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], ValueError)
assert "Failed to find the default VPC" in str(exc.value.exceptions[0])


@pytest.mark.usefixtures("ecs_mocks")
async def test_network_config_from_vpc_with_no_subnets(
Expand All @@ -1307,12 +1332,16 @@ async def test_network_config_from_vpc_with_no_subnets(
vpc_id=vpc.id,
)

with pytest.raises(
ValueError, match=f"Failed to find subnets for VPC with ID {vpc.id}"
):
with pytest.raises(ExceptionGroup) as exc: # novermin
async with ECSWorker(work_pool_name="test") as worker:
await run_then_stop_task(worker, configuration, flow_run)

assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], ValueError)
assert f"Failed to find subnets for VPC with ID {vpc.id}" in str(
exc.value.exceptions[0]
)


@pytest.mark.usefixtures("ecs_mocks")
@pytest.mark.parametrize("launch_type", ["FARGATE", "FARGATE_SPOT"])
Expand All @@ -1327,16 +1356,17 @@ async def test_bridge_network_mode_raises_on_fargate(
template_overrides=dict(task_definition={"networkMode": "bridge"}),
)

with pytest.raises(
ValueError,
match=(
"Found network mode 'bridge' which is not compatible with launch type "
f"{launch_type!r}"
),
):
with pytest.raises(ExceptionGroup) as exc: # novermin
async with ECSWorker(work_pool_name="test") as worker:
await run_then_stop_task(worker, configuration, flow_run)

assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], ValueError)
assert (
"Found network mode 'bridge' which is not compatible with launch type "
f"{launch_type!r}"
) in str(exc.value.exceptions[0])


@pytest.mark.usefixtures("ecs_mocks")
async def test_stream_output(
Expand Down Expand Up @@ -2280,10 +2310,13 @@ async def test_kill_infrastructure_with_invalid_identifier(aws_credentials):
aws_credentials=aws_credentials,
)

with pytest.raises(ValueError):
with pytest.raises(ExceptionGroup) as exc: # novermin
async with ECSWorker(work_pool_name="test") as worker:
await worker.kill_infrastructure(configuration, "test")

assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], ValueError)


@pytest.mark.usefixtures("ecs_mocks")
async def test_kill_infrastructure_with_mismatched_cluster(aws_credentials):
Expand All @@ -2292,16 +2325,17 @@ async def test_kill_infrastructure_with_mismatched_cluster(aws_credentials):
cluster="foo",
)

with pytest.raises(
InfrastructureNotAvailable,
match=(
"Cannot stop ECS task: this infrastructure block has access to cluster "
"'foo' but the task is running in cluster 'bar'."
),
):
with pytest.raises(ExceptionGroup) as exc: # novermin
async with ECSWorker(work_pool_name="test") as worker:
await worker.kill_infrastructure(configuration, "bar:::task_arn")

assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], InfrastructureNotAvailable)
assert (
"Cannot stop ECS task: this infrastructure block has access to cluster "
"'foo' but the task is running in cluster 'bar'."
) in str(exc.value.exceptions[0])


@pytest.mark.usefixtures("ecs_mocks")
async def test_kill_infrastructure_with_cluster_that_does_not_exist(aws_credentials):
Expand All @@ -2310,13 +2344,16 @@ async def test_kill_infrastructure_with_cluster_that_does_not_exist(aws_credenti
cluster="foo",
)

with pytest.raises(
InfrastructureNotFound,
match="Cannot stop ECS task: the cluster 'foo' could not be found.",
):
with pytest.raises(ExceptionGroup) as exc: # novermin
async with ECSWorker(work_pool_name="test") as worker:
await worker.kill_infrastructure(configuration, "foo::task_arn")

assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], InfrastructureNotFound)
assert ("Cannot stop ECS task: the cluster 'foo' could not be found.") in str(
exc.value.exceptions[0]
)


@pytest.mark.usefixtures("ecs_mocks")
async def test_kill_infrastructure_with_task_that_does_not_exist(
Expand Down Expand Up @@ -2348,13 +2385,16 @@ async def test_kill_infrastructure_with_cluster_that_has_no_tasks(aws_credential
cluster="default",
)

with pytest.raises(
InfrastructureNotFound,
match="Cannot stop ECS task: the cluster 'default' has no tasks.",
):
with pytest.raises(ExceptionGroup) as exc: # novermin
async with ECSWorker(work_pool_name="test") as worker:
await worker.kill_infrastructure(configuration, "default::foo")

assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], InfrastructureNotFound)
assert ("Cannot stop ECS task: the cluster 'default' has no tasks.") in str(
exc.value.exceptions[0]
)


@pytest.mark.usefixtures("ecs_mocks")
async def test_kill_infrastructure_with_task_that_is_already_stopped(
Expand Down Expand Up @@ -2417,10 +2457,12 @@ async def test_retry_on_failed_task_start(
},
)

with pytest.raises(RuntimeError):
with pytest.raises(ExceptionGroup) as exc: # novermin
async with ECSWorker(work_pool_name="test") as worker:
await run_then_stop_task(worker, configuration, flow_run)

assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], RuntimeError)
assert run_task_mock.call_count == 3


Expand Down
Loading

0 comments on commit 67774d3

Please sign in to comment.