fix!: delete workers from non-autoscaling fleets #124

Merged: 1 commit merged on Jul 4, 2024
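
Summary (inferred from the diff below; the PR description is not shown on this page): when a worker's fleet is created with autoscaling=False, EC2InstanceWorker.stop() now waits for the worker to reach the STOPPED status, forces the status with update_worker if the wait times out, and then removes the worker record with delete_worker. To support this, DeadlineWorkerConfiguration takes a Fleet object instead of a fleet_id string, and EC2InstanceWorker gains a deadline_client and a worker_id field. A minimal sketch of the cleanup flow, using only the service calls that appear in the diff (IDs are placeholders):

    import time
    import boto3

    deadline = boto3.client("deadline")
    farm_id, fleet_id, worker_id = "farm-123", "fleet-123", "worker-123"  # placeholder IDs

    # Wait for the worker to report STOPPED (mirrors wait_until_stopped below)...
    for _ in range(25):
        worker = deadline.get_worker(farmId=farm_id, fleetId=fleet_id, workerId=worker_id)
        if worker["status"] == "STOPPED":
            break
        time.sleep(5)
    else:
        # ...or force the status so the record can be deleted (mirrors set_stopped_status below).
        deadline.update_worker(farmId=farm_id, fleetId=fleet_id, workerId=worker_id, status="STOPPED")

    # Remove the worker from the non-autoscaling fleet (mirrors delete below).
    deadline.delete_worker(farmId=farm_id, fleetId=fleet_id, workerId=worker_id)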
2 changes: 1 addition & 1 deletion requirements-testing.txt
@@ -7,4 +7,4 @@ black == 24.4.*
moto[cloudformation,s3] == 4.2.*
mypy == 1.10.*
ruff == 0.4.*
twine == 5.0.*
twine == 5.1.*
1 change: 1 addition & 0 deletions src/deadline_test_fixtures/deadline/resources.py
@@ -103,6 +103,7 @@ def delete(self, *, client: DeadlineClient, raw_kwargs: dict | None = None) -> N
class Fleet:
id: str
farm: Farm
autoscaling: bool = True

@staticmethod
def create(
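Note (inferred from the new field's default): existing fleets are unaffected because autoscaling defaults to True; a fleet only opts into the stop-time worker deletion added below when it is constructed with autoscaling=False, for example (IDs are placeholders):

    fleet = Fleet(id="fleet-123", farm=Farm(id="farm-123"), autoscaling=False)
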
93 changes: 81 additions & 12 deletions src/deadline_test_fixtures/deadline/worker.py
@@ -23,6 +23,7 @@
PosixSessionUser,
OperatingSystem,
)
from .resources import Fleet
from ..util import call_api, wait_for

LOG = logging.getLogger(__name__)
@@ -43,7 +44,7 @@ def linux_worker_command(config: DeadlineWorkerConfiguration) -> str: # pragma:
"install-deadline-worker "
+ "-y "
+ f"--farm-id {config.farm_id} "
+ f"--fleet-id {config.fleet_id} "
+ f"--fleet-id {config.fleet.id} "
+ f"--region {config.region} "
+ f"--user {config.user} "
+ f"--group {config.group} "
@@ -80,7 +81,7 @@ def windows_worker_command(config: DeadlineWorkerConfiguration) -> str: # pragm
"install-deadline-worker "
+ "-y "
+ f"--farm-id {config.farm_id} "
+ f"--fleet-id {config.fleet_id} "
+ f"--fleet-id {config.fleet.id} "
+ f"--region {config.region} "
+ f"--user {config.user} "
+ f"{'--allow-shutdown ' if config.allow_shutdown else ''}"
@@ -130,10 +131,6 @@ def stop(self) -> None:
def send_command(self, command: str) -> CommandResult:
pass

@abc.abstractproperty
def worker_id(self) -> str:
pass


@dataclass(frozen=True)
class CommandResult: # pragma: no cover
@@ -173,7 +170,7 @@ def __str__(self) -> str:
class DeadlineWorkerConfiguration:
operating_system: OperatingSystem
farm_id: str
fleet_id: str
fleet: Fleet
region: str
user: str
group: str
@@ -203,11 +200,14 @@ class EC2InstanceWorker(DeadlineWorker):
s3_client: botocore.client.BaseClient
ec2_client: botocore.client.BaseClient
ssm_client: botocore.client.BaseClient
deadline_client: botocore.client.BaseClient
configuration: DeadlineWorkerConfiguration

instance_id: Optional[str] = field(init=False, default=None)

override_ami_id: InitVar[Optional[str]] = None
worker_id: Optional[str] = None

"""
Option to override the AMI ID for the EC2 instance. The latest AL2023 is used by default.
Note that the scripting to configure the EC2 instance is only verified to work on AL2023.
@@ -225,8 +225,66 @@ def start(self) -> None:
def stop(self) -> None:
LOG.info(f"Terminating EC2 instance {self.instance_id}")
self.ec2_client.terminate_instances(InstanceIds=[self.instance_id])

self.instance_id = None

if not self.configuration.fleet.autoscaling:
try:
self.wait_until_stopped()
except TimeoutError:
LOG.warning(
f"{self.worker_id} did not transition to a STOPPED status, forcibly stopping..."
)
self.set_stopped_status()

try:
self.delete()
except botocore.exceptions.ClientError as error:
LOG.exception(f"Failed to delete worker: {error}")
raise

def delete(self):
try:
self.deadline_client.delete_worker(
farmId=self.configuration.farm_id,
fleetId=self.configuration.fleet.id,
workerId=self.worker_id,
)
LOG.info(f"{self.worker_id} has been deleted from {self.configuration.fleet.id}")
except botocore.exceptions.ClientError as error:
LOG.exception(f"Failed to delete worker: {error}")
raise

def wait_until_stopped(
self, *, max_checks: int = 25, seconds_between_checks: float = 5
) -> None:
for _ in range(max_checks):
response = self.deadline_client.get_worker(
farmId=self.configuration.farm_id,
fleetId=self.configuration.fleet.id,
workerId=self.worker_id,
)
if response["status"] == "STOPPED":
LOG.info(f"{self.worker_id} is STOPPED")
break
time.sleep(seconds_between_checks)
LOG.info(f"Waiting for {self.worker_id} to transition to STOPPED status")
else:
raise TimeoutError

def set_stopped_status(self):
LOG.info(f"Setting {self.worker_id} to STOPPED status")
try:
self.deadline_client.update_worker(
farmId=self.configuration.farm_id,
fleetId=self.configuration.fleet.id,
workerId=self.worker_id,
status="STOPPED",
)
except botocore.exceptions.ClientError as error:
LOG.exception(f"Failed to update worker status: {error}")
raise

def send_command(self, command: str) -> CommandResult:
"""Send a command via SSM to a shell on a launched EC2 instance. Once the command has fully
finished the result of the invocation is returned.
@@ -240,7 +298,7 @@ def send_command(self, command: str) -> CommandResult:
#
# If we send an SSM command then we will get an InvalidInstanceId error
# if the instance isn't in that state.
NUM_RETRIES = 20
NUM_RETRIES = 30
SLEEP_INTERVAL_S = 10
for i in range(0, NUM_RETRIES):
LOG.info(f"Sending SSM command to instance {self.instance_id}")
@@ -491,9 +549,20 @@ def _start_worker_agent(self) -> None: # pragma: no cover
else:
self.start_windows_worker()

@property
def worker_id(self) -> str:
cmd_result = self.send_command("cat /var/lib/deadline/worker.json | jq -r '.worker_id'")
self.worker_id = self.get_worker_id()

def get_worker_id(self) -> str:
if self.configuration.operating_system.name == "AL2023":
cmd_result = self.send_command("jq -r '.worker_id' /var/lib/deadline/worker.json")
else:
cmd_result = self.send_command(
" ; ".join(
[
"$worker=Get-Content -Raw C:\ProgramData\Amazon\Deadline\Cache\worker.json | ConvertFrom-Json",
"echo $worker.worker_id",
]
)
)
assert cmd_result.exit_code == 0, f"Failed to get Worker ID: {cmd_result}"

worker_id = cmd_result.stdout.rstrip("\n\r")
@@ -553,7 +622,7 @@ def start(self) -> None:
run_container_env = {
**os.environ,
"FARM_ID": self.configuration.farm_id,
"FLEET_ID": self.configuration.fleet_id,
"FLEET_ID": self.configuration.fleet.id,
"AGENT_USER": self.configuration.user,
"SHARED_GROUP": self.configuration.group,
"JOB_USER": self.configuration.job_users[0].user,
4 changes: 3 additions & 1 deletion src/deadline_test_fixtures/fixtures.py
@@ -454,7 +454,7 @@ def worker_config(

yield DeadlineWorkerConfiguration(
farm_id=deadline_resources.farm.id,
fleet_id=deadline_resources.fleet.id,
fleet=deadline_resources.fleet,
region=region,
user=os.getenv("WORKER_POSIX_USER", "deadline-worker"),
group=os.getenv("WORKER_POSIX_SHARED_GROUP", "shared-group"),
@@ -514,10 +514,12 @@ def worker(
ec2_client = boto3.client("ec2")
s3_client = boto3.client("s3")
ssm_client = boto3.client("ssm")
deadline_client = boto3.client("deadline")

worker = EC2InstanceWorker(
ec2_client=ec2_client,
s3_client=s3_client,
deadline_client=deadline_client,
bootstrap_bucket_name=bootstrap_resources.bootstrap_bucket_name,
ssm_client=ssm_client,
override_ami_id=ami_id,
42 changes: 34 additions & 8 deletions test/unit/deadline/test_worker.py
@@ -14,14 +14,16 @@

from deadline_test_fixtures.deadline import worker as mod
from deadline_test_fixtures import (
CodeArtifactRepositoryInfo,
CommandResult,
DeadlineWorkerConfiguration,
DockerContainerWorker,
EC2InstanceWorker,
PipInstall,
CodeArtifactRepositoryInfo,
OperatingSystem,
S3Object,
Fleet,
Farm,
)


@@ -62,7 +64,7 @@ def region(boto_config: dict[str, str]) -> str:
def worker_config(region: str) -> DeadlineWorkerConfiguration:
return DeadlineWorkerConfiguration(
farm_id="farm-123",
fleet_id="fleet-123",
fleet=Fleet(id="fleet_123", farm=Farm(id="farm-123")),
region=region,
user="test-user",
group="test-group",
@@ -157,7 +159,7 @@ def worker(
s3_client=boto3.client("s3"),
ec2_client=boto3.client("ec2"),
ssm_client=boto3.client("ssm"),
deadline_client=boto3.client("deadline"),
configuration=worker_config,
worker_id="worker-7c3377ec9eba444bb51cc7da18463081",
)

@patch.object(mod, "open", mock_open(read_data="mock data".encode()))
@@ -171,6 +175,13 @@ def test_start(self, worker: EC2InstanceWorker) -> None:
patch.object(worker, "_stage_s3_bucket", return_value=s3_files) as mock_stage_s3_bucket,
patch.object(worker, "_launch_instance") as mock_launch_instance,
patch.object(worker, "_start_worker_agent") as mock_start_worker_agent,
patch.object(
worker,
"get_worker_id",
return_value=CommandResult(
exit_code=0, stdout="worker-7c3377ec9eba444bb51cc7da18463081"
),
),
):
# WHEN
worker.start()
@@ -240,14 +251,17 @@ def test_start_worker_agent(self) -> None:

def test_stop(self, worker: EC2InstanceWorker) -> None:
# GIVEN
worker.start()
# WHEN
with patch.object(
worker, "get_worker_id", return_value="worker-7c3377ec9eba444bb51cc7da18463081"
):
worker.start()
instance_id = worker.instance_id
assert instance_id is not None

instance = TestEC2InstanceWorker.describe_instance(instance_id)
assert instance["State"]["Name"] == "running"

# WHEN
worker.stop()

# THEN
@@ -259,7 +273,11 @@ class TestSendCommand:
def test_sends_command(self, worker: EC2InstanceWorker) -> None:
# GIVEN
cmd = 'echo "Hello world"'
worker.start()
# WHEN
with patch.object(
worker, "get_worker_id", return_value="worker-7c3377ec9eba444bb51cc7da18463081"
):
worker.start()

# WHEN
with patch.object(
@@ -277,7 +295,11 @@ def test_sends_command(self, worker: EC2InstanceWorker) -> None:
def test_retries_when_instance_not_ready(self, worker: EC2InstanceWorker) -> None:
# GIVEN
cmd = 'echo "Hello world"'
worker.start()
# WHEN
with patch.object(
worker, "get_worker_id", return_value="worker-7c3377ec9eba444bb51cc7da18463081"
):
worker.start()
real_send_command = worker.ssm_client.send_command

call_count = 0
@@ -311,7 +333,11 @@ def side_effect(*args, **kwargs):
def test_raises_any_other_error(self, worker: EC2InstanceWorker) -> None:
# GIVEN
cmd = 'echo "Hello world"'
worker.start()
# WHEN
with patch.object(
worker, "get_worker_id", return_value="worker-7c3377ec9eba444bb51cc7da18463081"
):
worker.start()
err = ClientError({"Error": {"Code": "SomethingWentWrong"}}, "SendCommand")

# WHEN
@@ -337,7 +363,7 @@ def test_raises_any_other_error(self, worker: EC2InstanceWorker) -> None:
"worker-7c3377ec9eba444bb51cc7da18463081\r\n",
],
)
def test_worker_id(self, worker_id: str, worker: EC2InstanceWorker) -> None:
def test_get_worker_id(self, worker_id: str, worker: EC2InstanceWorker) -> None:
# GIVEN
with patch.object(
worker, "send_command", return_value=CommandResult(exit_code=0, stdout=worker_id)
Expand Down