diff --git a/awscli/customizations/waiters.py b/awscli/customizations/waiters.py index c54b45f5fe1c..9440beb8e42b 100644 --- a/awscli/customizations/waiters.py +++ b/awscli/customizations/waiters.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from botocore import xform_name +from botocore.exceptions import DataNotFoundError from awscli.clidriver import ServiceOperation from awscli.customizations.commands import BasicCommand, BasicHelp, \ @@ -28,29 +29,34 @@ def add_waiters(command_table, session, command_object, **kwargs): service_object = getattr(command_object, 'service_object', None) if service_object is not None: # Get a client out of the service object. - client = translate_service_object_to_client(service_object) - # Find all of the waiters for that client. - waiters = client.waiter_names + waiter_model = get_waiter_model_from_service_object(service_object) + if waiter_model is None: + return + waiter_names = waiter_model.waiter_names # If there are waiters make a wait command. - if waiters: - command_table['wait'] = WaitCommand(client, service_object) + if waiter_names: + command_table['wait'] = WaitCommand(waiter_model, service_object) -def translate_service_object_to_client(service_object): - # Create a client from a service object. +def get_waiter_model_from_service_object(service_object): session = service_object.session - return session.create_client(service_object.service_name) + try: + model = session.get_waiter_model(service_object.service_name, + service_object.api_version) + except DataNotFoundError: + return None + return model class WaitCommand(BasicCommand): NAME = 'wait' DESCRIPTION = 'Wait until a particular condition is satisfied.' - def __init__(self, client, service_object): - self._client = client + def __init__(self, waiter_model, service_object): + self._model = waiter_model self._service_object = service_object self.waiter_cmd_builder = WaiterStateCommandBuilder( - client=self._client, + model=self._model, service_object=self._service_object ) super(WaitCommand, self).__init__(self._service_object.session) @@ -73,8 +79,8 @@ def create_help_command(self): class WaiterStateCommandBuilder(object): - def __init__(self, client, service_object): - self._client = client + def __init__(self, model, service_object): + self._model = model self._service_object = service_object def build_all_waiter_state_cmds(self, subcommand_table): @@ -83,22 +89,22 @@ def build_all_waiter_state_cmds(self, subcommand_table): This is the method that adds waiter state commands like ``instance-running`` to ``ec2 wait``. """ - waiters = self._client.waiter_names - for waiter_name in waiters: - waiter_cli_name = waiter_name.replace('_', '-') + waiter_names = self._model.waiter_names + for waiter_name in waiter_names: + waiter_cli_name = xform_name(waiter_name, '-') subcommand_table[waiter_cli_name] = \ self._build_waiter_state_cmd(waiter_name) def _build_waiter_state_cmd(self, waiter_name): # Get the waiter - waiter = self._client.get_waiter(waiter_name) + waiter_config = self._model.get_waiter(waiter_name) # Create the cli name for the waiter operation - waiter_cli_name = waiter_name.replace('_', '-') + waiter_cli_name = xform_name(waiter_name, '-') # Obtain the name of the service operation that is used to implement # the specified waiter. - operation_name = waiter.config.operation + operation_name = waiter_config.operation # Create an operation object to make a command for the waiter. The # operation object is used to generate the arguments for the waiter @@ -107,13 +113,13 @@ def _build_waiter_state_cmd(self, waiter_name): waiter_state_command = WaiterStateCommand( name=waiter_cli_name, parent_name='wait', operation_object=operation_object, - operation_caller=WaiterCaller(self._client, waiter_name), + operation_caller=WaiterCaller(waiter_name), service_object=self._service_object ) # Build the top level description for the waiter state command. # Most waiters do not have a description so they need to be generated # using the waiter configuration. - waiter_state_doc_builder = WaiterStateDocBuilder(waiter.config) + waiter_state_doc_builder = WaiterStateDocBuilder(waiter_config) description = waiter_state_doc_builder.build_waiter_state_description() waiter_state_command.DESCRIPTION = description return waiter_state_command @@ -172,20 +178,19 @@ def _build_operation_description(self, operation): class WaiterCaller(object): - def __init__(self, client, waiter_name): - self._client = client + def __init__(self, waiter_name): self._waiter_name = waiter_name def invoke(self, operation_object, parameters, parsed_globals): # Create the endpoint based on the parsed globals - endpoint = operation_object.service.get_endpoint( + service_object = operation_object.service + endpoint = service_object.get_endpoint( region_name=parsed_globals.region, endpoint_url=parsed_globals.endpoint_url, verify=parsed_globals.verify_ssl) - # Make a clone of the client using the newly configured endpoint - client = self._client.clone_client(endpoint=endpoint) - # Make the waiter and call its wait method. - client.get_waiter(self._waiter_name).wait(**parameters) + waiter = service_object.get_waiter( + self._waiter_name, endpoint) + waiter.wait(**parameters) return 0 diff --git a/tests/unit/customizations/test_waiters.py b/tests/unit/customizations/test_waiters.py index 73a6dea63497..aa08f3bfc88b 100644 --- a/tests/unit/customizations/test_waiters.py +++ b/tests/unit/customizations/test_waiters.py @@ -12,10 +12,13 @@ # language governing permissions and limitations under the License. import mock +from botocore.waiter import WaiterModel +from botocore.exceptions import DataNotFoundError + from awscli.testutils import unittest, BaseAWSHelpOutputTest, \ BaseAWSCommandParamsTest from awscli.customizations.waiters import add_waiters, WaitCommand, \ - translate_service_object_to_client, WaiterStateCommand, WaiterCaller, \ + get_waiter_model_from_service_object, WaiterStateCommand, WaiterCaller, \ WaiterStateDocBuilder, WaiterStateCommandBuilder @@ -23,7 +26,6 @@ class TestAddWaiters(unittest.TestCase): def setUp(self): self.service_object = mock.Mock() self.session = mock.Mock() - self.client = mock.Mock() self.command_object = mock.Mock() self.command_object.service_object = self.service_object @@ -32,10 +34,15 @@ def setUp(self): self.service_object.session = self.session # Set up the mock session. - self.session.create_client.return_value = self.client + self.session.get_waiter_model.return_value = WaiterModel( + { + 'version': 2, + 'waiters': { + 'FooExists': {}, + } + } + ) - # Set up the mock client. - self.client.waiter_names = ['waiter'] def test_add_waiters(self): command_table = {} @@ -45,7 +52,13 @@ def test_add_waiters(self): self.assertIsInstance(command_table['wait'], WaitCommand) def test_add_waiters_no_waiter_names(self): - self.client.waiter_names = [] + self.session.get_waiter_model.return_value = WaiterModel( + { + 'version': 2, + # No waiters are specified. + 'waiters': {} + } + ) command_table = {} add_waiters(command_table, self.session, self.command_object) # Make sure that no wait command was added since the service object @@ -60,22 +73,49 @@ def test_add_waiters_no_service_object(self): # was passed in. self.assertEqual(command_table, {}) + def test_add_waiter_no_waiter_config(self): + self.session.get_waiter_model.side_effect = DataNotFoundError( + data_path='foo') + command_table = {} + add_waiters(command_table, self.session, self.command_object) + self.assertEqual(command_table, {}) + + +class TestServicetoWaiterModel(unittest.TestCase): + def test_service_object_to_waiter_model(self): + service_object = mock.Mock() + session = mock.Mock() + service_object.session = session + service_object.service_name = 'service' + service_object.api_version = '2014-01-01' + get_waiter_model_from_service_object(service_object) + session.get_waiter_model.assert_called_with('service', '2014-01-01') -class TestTranslateServiceObjectToClient(unittest.TestCase): - def test_translate_service_object_to_client(self): + def test_can_handle_data_errors(self): service_object = mock.Mock() session = mock.Mock() service_object.session = session service_object.service_name = 'service' - translate_service_object_to_client(service_object) - session.create_client.assert_called_with('service') + service_object.api_version = '2014-01-01' + session.get_waiter_model.side_effect = DataNotFoundError( + data_path='foo') + self.assertIsNone( + get_waiter_model_from_service_object(service_object)) class TestWaitCommand(unittest.TestCase): def setUp(self): - self.client = mock.Mock() + self.model = WaiterModel({ + 'version': 2, + 'waiters': { + 'Foo': { + 'operation': 'foo', 'maxAttempts': 1, 'delay': 1, + 'acceptors': [], + } + } + }) self.service_object = mock.Mock() - self.cmd = WaitCommand(self.client, self.service_object) + self.cmd = WaitCommand(self.model, self.service_object) def test_run_main_error(self): self.parsed_args = mock.Mock() @@ -146,26 +186,29 @@ def test_elastictranscoder_jobs_complete(self): class TestWaiterStateCommandBuilder(unittest.TestCase): def setUp(self): - self.client = mock.Mock() self.service_object = mock.Mock() # Create some waiters. - self.client.waiter_names = ['instance_running', 'bucket_exists'] - self.instance_running_waiter = mock.Mock() - self.bucket_exists_waiter = mock.Mock() - - # Make a mock waiter config. - self.waiter_config = mock.Mock() - self.waiter_config.operation = 'MyOperation' - self.waiter_config.description = 'my waiter description' - self.instance_running_waiter.config = self.waiter_config - self.bucket_exists_waiter.config = self.waiter_config - - self.client.get_waiter.side_effect = [ - self.instance_running_waiter, self.bucket_exists_waiter] + self.model = WaiterModel({ + 'version': 2, + 'waiters': { + 'InstanceRunning': { + 'description': 'my waiter description', + 'delay': 1, + 'maxAttempts': 10, + 'operation': 'MyOperation', + }, + 'BucketExists': { + 'description': 'my waiter description', + 'operation': 'MyOperation', + 'delay': 1, + 'maxAttempts': 10, + } + } + }) self.waiter_builder = WaiterStateCommandBuilder( - self.client, + self.model, self.service_object ) @@ -191,11 +234,11 @@ def test_build_waiter_state_cmds(self): # Check the descriptions are set correctly. self.assertEqual( instance_running_cmd.DESCRIPTION, - self.waiter_config.description + 'my waiter description', ) self.assertEqual( bucket_exists_cmd.DESCRIPTION, - self.waiter_config.description + 'my waiter description', ) @@ -282,15 +325,10 @@ def test_path_any_acceptor(self): class TestWaiterCaller(unittest.TestCase): def test_invoke(self): - client = mock.Mock() waiter = mock.Mock() waiter_name = 'my_waiter' operation_object = mock.Mock() - - # Mock the clone of the client - cloned_client = mock.Mock() - cloned_client.get_waiter.return_value = waiter - client.clone_client.return_value = cloned_client + operation_object.service.get_waiter.return_value = waiter parameters = {'Foo': 'bar', 'Baz': 'biz'} parsed_globals = mock.Mock() @@ -298,7 +336,7 @@ def test_invoke(self): parsed_globals.endpoint_url = 'myurl' parsed_globals.verify_ssl = True - waiter_caller = WaiterCaller(client, waiter_name) + waiter_caller = WaiterCaller(waiter_name) waiter_caller.invoke(operation_object, parameters, parsed_globals) # Make sure the endpoint was created properly operation_object.service.get_endpoint.assert_called_with( @@ -306,11 +344,7 @@ def test_invoke(self): endpoint_url=parsed_globals.endpoint_url, verify=parsed_globals.verify_ssl ) - # Ensure the client was cloned with using the new endpoint. - clone_kwargs = client.clone_client.call_args[1] - self.assertIn('endpoint', clone_kwargs) - # Ensure we get the waiter. - cloned_client.get_waiter.assert_called_with(waiter_name) + # Ensure the wait command was called properly. waiter.wait.assert_called_with( Foo='bar', Baz='biz')