diff --git a/domino_maintenance_mode/cli.py b/domino_maintenance_mode/cli.py index 39f635d..a974e2e 100644 --- a/domino_maintenance_mode/cli.py +++ b/domino_maintenance_mode/cli.py @@ -102,6 +102,18 @@ async def _async_snapshot(output, **kwargs): json.dump(state, output) +def validate_services(ctx, param, value): + if not value: + return None + elif value not in __get_execution_interfaces().keys(): + raise click.BadParameter( + "Services must be one of " + f"{list(__get_execution_interfaces().keys())}" + ) + else: + return value + + @click.command() @click.argument("snapshot", type=click.File("r")) @click.option( @@ -138,6 +150,14 @@ async def _async_snapshot(output, **kwargs): default=600, help="Amount of time to wait for executions to complete.", ) +@click.option( + "-s", + "--service", + type=str, + help="(Optional) Service to shutdown. Options are: " + f"{list(__get_execution_interfaces().keys())}", + callback=validate_services, +) def shutdown(snapshot, **kwargs): """Stop running Apps, Model APIs, Durable Workspaces, and Scheduled Jobs. @@ -145,10 +165,16 @@ def shutdown(snapshot, **kwargs): """ state = __load_state(snapshot) manager = Manager(**kwargs) - for interface in __get_execution_interfaces().values(): + if manager.get_service(): + interface = __get_execution_interfaces()[manager.get_service()] executions = state[interface.singular()] if len(executions) > 0: manager.stop(interface, executions) + else: + for interface in __get_execution_interfaces().values(): + executions = state[interface.singular()] + if len(executions) > 0: + manager.stop(interface, executions) cli.add_command(shutdown) diff --git a/domino_maintenance_mode/manager.py b/domino_maintenance_mode/manager.py index f4305c7..105b72e 100644 --- a/domino_maintenance_mode/manager.py +++ b/domino_maintenance_mode/manager.py @@ -28,6 +28,7 @@ class Manager: def __init__( self, + service: str = None, batch_size: int = 5, batch_interval_s: int = 5, max_failures: int = 5, @@ -37,6 +38,10 @@ def __init__( self.batch_interval_s = batch_interval_s self.grace_period_s = grace_period_s self.max_failures = max_failures + self.service = service + + def get_service(self): + return self.service def stop(self, interface: ExecutionInterface, executions: List[Execution]): self.__toggle_executions(