Skip to content

Commit

Permalink
Update waiters to use the new get_waiter_model
Browse files Browse the repository at this point in the history
This fixes an issue where previously we were creating a client
without specifying a region, which in certain cases would trigger
an error.  Now we only use the waiter model to generate the necessary
commands.

As part of thie change I went ahead and just switched the invoke
to just use the service/operation object.  This isn't ideal, but at the
same time, it's pretty clear that for the time being there's no way
to do this entirely with just clients.  Given that's the case, I don't
feel it's that imperative to use clients where possible.  When
we switch to clients, this module will need updates regardless.
It also simplified the code a little bit.
  • Loading branch information
jamesls committed Nov 10, 2014
1 parent 40bca40 commit b89d6f0
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 69 deletions.
61 changes: 33 additions & 28 deletions awscli/customizations/waiters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from botocore import xform_name
from botocore.exceptions import DataNotFoundError

from awscli.clidriver import ServiceOperation
from awscli.customizations.commands import BasicCommand, BasicHelp, \
Expand All @@ -28,29 +29,34 @@ def add_waiters(command_table, session, command_object, **kwargs):
service_object = getattr(command_object, 'service_object', None)
if service_object is not None:
# Get a client out of the service object.
client = translate_service_object_to_client(service_object)
# Find all of the waiters for that client.
waiters = client.waiter_names
waiter_model = get_waiter_model_from_service_object(service_object)
if waiter_model is None:
return
waiter_names = waiter_model.waiter_names
# If there are waiters make a wait command.
if waiters:
command_table['wait'] = WaitCommand(client, service_object)
if waiter_names:
command_table['wait'] = WaitCommand(waiter_model, service_object)


def translate_service_object_to_client(service_object):
# Create a client from a service object.
def get_waiter_model_from_service_object(service_object):
session = service_object.session
return session.create_client(service_object.service_name)
try:
model = session.get_waiter_model(service_object.service_name,
service_object.api_version)

This comment has been minimized.

Copy link
@kyleknap

kyleknap Nov 10, 2014

Wrong indentation. service_object.api_version not under parenthesis.

except DataNotFoundError:
return None
return model


class WaitCommand(BasicCommand):
NAME = 'wait'
DESCRIPTION = 'Wait until a particular condition is satisfied.'

def __init__(self, client, service_object):
self._client = client
def __init__(self, waiter_model, service_object):
self._model = waiter_model
self._service_object = service_object
self.waiter_cmd_builder = WaiterStateCommandBuilder(
client=self._client,
model=self._model,
service_object=self._service_object
)
super(WaitCommand, self).__init__(self._service_object.session)
Expand All @@ -73,8 +79,8 @@ def create_help_command(self):


class WaiterStateCommandBuilder(object):
def __init__(self, client, service_object):
self._client = client
def __init__(self, model, service_object):
self._model = model
self._service_object = service_object

def build_all_waiter_state_cmds(self, subcommand_table):
Expand All @@ -83,22 +89,22 @@ def build_all_waiter_state_cmds(self, subcommand_table):
This is the method that adds waiter state commands like
``instance-running`` to ``ec2 wait``.
"""
waiters = self._client.waiter_names
for waiter_name in waiters:
waiter_cli_name = waiter_name.replace('_', '-')
waiter_names = self._model.waiter_names
for waiter_name in waiter_names:
waiter_cli_name = xform_name(waiter_name, '-')
subcommand_table[waiter_cli_name] = \
self._build_waiter_state_cmd(waiter_name)

def _build_waiter_state_cmd(self, waiter_name):
# Get the waiter
waiter = self._client.get_waiter(waiter_name)
waiter_config = self._model.get_waiter(waiter_name)

# Create the cli name for the waiter operation
waiter_cli_name = waiter_name.replace('_', '-')
waiter_cli_name = xform_name(waiter_name, '-')

# Obtain the name of the service operation that is used to implement
# the specified waiter.
operation_name = waiter.config.operation
operation_name = waiter_config.operation

# Create an operation object to make a command for the waiter. The
# operation object is used to generate the arguments for the waiter
Expand All @@ -107,13 +113,13 @@ def _build_waiter_state_cmd(self, waiter_name):
waiter_state_command = WaiterStateCommand(
name=waiter_cli_name, parent_name='wait',
operation_object=operation_object,
operation_caller=WaiterCaller(self._client, waiter_name),
operation_caller=WaiterCaller(waiter_name),
service_object=self._service_object
)
# Build the top level description for the waiter state command.
# Most waiters do not have a description so they need to be generated
# using the waiter configuration.
waiter_state_doc_builder = WaiterStateDocBuilder(waiter.config)
waiter_state_doc_builder = WaiterStateDocBuilder(waiter_config)
description = waiter_state_doc_builder.build_waiter_state_description()
waiter_state_command.DESCRIPTION = description
return waiter_state_command
Expand Down Expand Up @@ -172,20 +178,19 @@ def _build_operation_description(self, operation):


class WaiterCaller(object):
def __init__(self, client, waiter_name):
self._client = client
def __init__(self, waiter_name):
self._waiter_name = waiter_name

def invoke(self, operation_object, parameters, parsed_globals):
# Create the endpoint based on the parsed globals
endpoint = operation_object.service.get_endpoint(
service_object = operation_object.service
endpoint = service_object.get_endpoint(
region_name=parsed_globals.region,
endpoint_url=parsed_globals.endpoint_url,
verify=parsed_globals.verify_ssl)
# Make a clone of the client using the newly configured endpoint
client = self._client.clone_client(endpoint=endpoint)
# Make the waiter and call its wait method.
client.get_waiter(self._waiter_name).wait(**parameters)
waiter = service_object.get_waiter(
self._waiter_name, endpoint)
waiter.wait(**parameters)
return 0


Expand Down
116 changes: 75 additions & 41 deletions tests/unit/customizations/test_waiters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,20 @@
# language governing permissions and limitations under the License.
import mock

from botocore.waiter import WaiterModel
from botocore.exceptions import DataNotFoundError

from awscli.testutils import unittest, BaseAWSHelpOutputTest, \
BaseAWSCommandParamsTest
from awscli.customizations.waiters import add_waiters, WaitCommand, \
translate_service_object_to_client, WaiterStateCommand, WaiterCaller, \
get_waiter_model_from_service_object, WaiterStateCommand, WaiterCaller, \
WaiterStateDocBuilder, WaiterStateCommandBuilder


class TestAddWaiters(unittest.TestCase):
def setUp(self):
self.service_object = mock.Mock()
self.session = mock.Mock()
self.client = mock.Mock()

self.command_object = mock.Mock()
self.command_object.service_object = self.service_object
Expand All @@ -32,10 +34,15 @@ def setUp(self):
self.service_object.session = self.session

# Set up the mock session.
self.session.create_client.return_value = self.client
self.session.get_waiter_model.return_value = WaiterModel(
{
'version': 2,
'waiters': {
'FooExists': {},
}
}
)

# Set up the mock client.
self.client.waiter_names = ['waiter']

def test_add_waiters(self):
command_table = {}
Expand All @@ -45,7 +52,13 @@ def test_add_waiters(self):
self.assertIsInstance(command_table['wait'], WaitCommand)

def test_add_waiters_no_waiter_names(self):
self.client.waiter_names = []
self.session.get_waiter_model.return_value = WaiterModel(
{
'version': 2,
# No waiters are specified.
'waiters': {}
}
)
command_table = {}
add_waiters(command_table, self.session, self.command_object)
# Make sure that no wait command was added since the service object
Expand All @@ -60,22 +73,49 @@ def test_add_waiters_no_service_object(self):
# was passed in.
self.assertEqual(command_table, {})

def test_add_waiter_no_waiter_config(self):
self.session.get_waiter_model.side_effect = DataNotFoundError(
data_path='foo')
command_table = {}
add_waiters(command_table, self.session, self.command_object)
self.assertEqual(command_table, {})


class TestServicetoWaiterModel(unittest.TestCase):
def test_service_object_to_waiter_model(self):
service_object = mock.Mock()
session = mock.Mock()
service_object.session = session
service_object.service_name = 'service'
service_object.api_version = '2014-01-01'
get_waiter_model_from_service_object(service_object)
session.get_waiter_model.assert_called_with('service', '2014-01-01')

class TestTranslateServiceObjectToClient(unittest.TestCase):
def test_translate_service_object_to_client(self):
def test_can_handle_data_errors(self):
service_object = mock.Mock()
session = mock.Mock()
service_object.session = session
service_object.service_name = 'service'
translate_service_object_to_client(service_object)
session.create_client.assert_called_with('service')
service_object.api_version = '2014-01-01'
session.get_waiter_model.side_effect = DataNotFoundError(
data_path='foo')
self.assertIsNone(
get_waiter_model_from_service_object(service_object))


class TestWaitCommand(unittest.TestCase):
def setUp(self):
self.client = mock.Mock()
self.model = WaiterModel({
'version': 2,
'waiters': {
'Foo': {
'operation': 'foo', 'maxAttempts': 1, 'delay': 1,
'acceptors': [],
}
}
})
self.service_object = mock.Mock()
self.cmd = WaitCommand(self.client, self.service_object)
self.cmd = WaitCommand(self.model, self.service_object)

def test_run_main_error(self):
self.parsed_args = mock.Mock()
Expand Down Expand Up @@ -146,26 +186,29 @@ def test_elastictranscoder_jobs_complete(self):

class TestWaiterStateCommandBuilder(unittest.TestCase):
def setUp(self):
self.client = mock.Mock()
self.service_object = mock.Mock()

# Create some waiters.
self.client.waiter_names = ['instance_running', 'bucket_exists']
self.instance_running_waiter = mock.Mock()
self.bucket_exists_waiter = mock.Mock()

# Make a mock waiter config.
self.waiter_config = mock.Mock()
self.waiter_config.operation = 'MyOperation'
self.waiter_config.description = 'my waiter description'
self.instance_running_waiter.config = self.waiter_config
self.bucket_exists_waiter.config = self.waiter_config

self.client.get_waiter.side_effect = [
self.instance_running_waiter, self.bucket_exists_waiter]
self.model = WaiterModel({
'version': 2,
'waiters': {
'InstanceRunning': {
'description': 'my waiter description',
'delay': 1,
'maxAttempts': 10,
'operation': 'MyOperation',
},
'BucketExists': {
'description': 'my waiter description',
'operation': 'MyOperation',
'delay': 1,
'maxAttempts': 10,
}
}
})

self.waiter_builder = WaiterStateCommandBuilder(
self.client,
self.model,
self.service_object
)

Expand All @@ -191,11 +234,11 @@ def test_build_waiter_state_cmds(self):
# Check the descriptions are set correctly.
self.assertEqual(
instance_running_cmd.DESCRIPTION,
self.waiter_config.description
'my waiter description',
)
self.assertEqual(
bucket_exists_cmd.DESCRIPTION,
self.waiter_config.description
'my waiter description',
)


Expand Down Expand Up @@ -282,35 +325,26 @@ def test_path_any_acceptor(self):

class TestWaiterCaller(unittest.TestCase):
def test_invoke(self):
client = mock.Mock()
waiter = mock.Mock()
waiter_name = 'my_waiter'
operation_object = mock.Mock()

# Mock the clone of the client
cloned_client = mock.Mock()
cloned_client.get_waiter.return_value = waiter
client.clone_client.return_value = cloned_client
operation_object.service.get_waiter.return_value = waiter

parameters = {'Foo': 'bar', 'Baz': 'biz'}
parsed_globals = mock.Mock()
parsed_globals.region = 'us-east-1'
parsed_globals.endpoint_url = 'myurl'
parsed_globals.verify_ssl = True

waiter_caller = WaiterCaller(client, waiter_name)
waiter_caller = WaiterCaller(waiter_name)
waiter_caller.invoke(operation_object, parameters, parsed_globals)
# Make sure the endpoint was created properly
operation_object.service.get_endpoint.assert_called_with(
region_name=parsed_globals.region,
endpoint_url=parsed_globals.endpoint_url,
verify=parsed_globals.verify_ssl
)
# Ensure the client was cloned with using the new endpoint.
clone_kwargs = client.clone_client.call_args[1]
self.assertIn('endpoint', clone_kwargs)
# Ensure we get the waiter.
cloned_client.get_waiter.assert_called_with(waiter_name)

# Ensure the wait command was called properly.
waiter.wait.assert_called_with(
Foo='bar', Baz='biz')
Expand Down

0 comments on commit b89d6f0

Please sign in to comment.