Skip to content

Commit

Permalink
Refactor S3 client initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
tremble committed Jan 19, 2023
1 parent 9733031 commit 974c4e8
Showing 1 changed file with 34 additions and 17 deletions.
51 changes: 34 additions & 17 deletions plugins/connection/aws_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ class Connection(ConnectionBase):
has_pipelining = False
is_windows = False
_client = None
_s3_client = None
_session = None
_stdout = None
_session_id = ''
Expand All @@ -323,6 +324,30 @@ def _vvv(self, message):
def _vvvv(self, message):
self._display(display.vvvv, message)

def _init_clients(self):
self._vvvv("INITIALIZE BOTO3 CLIENTS")
profile_name = self.get_option('profile') or ''
region_name = self.get_option('region')

# The SSM Boto client, currently used to initiate and manage the session
# Note: does not handle the actual SSM session traffic
self._vvvv("SETUP BOTO3 CLIENTS: SSM")
ssm_client = self._get_boto_client('ssm', region_name=region_name, profile_name=profile_name)
self._client = ssm_client

region_name = self.get_option('region') or 'us-east-1'
self._vvvv("SETUP BOTO3 CLIENTS: S3 (tmp)")
tmp_s3_client = self._get_boto_client('s3', region_name=region_name, profile_name=profile_name)
# Fetch the location of the bucket so we can open a client against the 'right' endpoint
bucket_location = tmp_s3_client.get_bucket_location(
Bucket=(self.get_option('bucket_name')),
)
bucket_region = bucket_location['LocationConstraint']
# This is the S3 client we'll really be using
self._vvvv(f"SETUP BOTO3 CLIENTS: S3 - {bucket_region}")
s3_bucket_client = self._get_boto_client('s3', region_name=bucket_region, profile_name=profile_name)
self._s3_client = s3_bucket_client

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -372,27 +397,26 @@ def start_session(self):
if not os.path.exists(to_bytes(executable, errors='surrogate_or_strict')):
raise AnsibleError(f"failed to find the executable specified {executable}.")

profile_name = self.get_option('profile') or ''
region_name = self.get_option('region')
ssm_parameters = dict()
client = self._get_boto_client('ssm', region_name=region_name, profile_name=profile_name)
self._client = client
self._init_clients()

self._vvvv(f"START SSM SESSION: {self.instance_id}")
start_session_args = dict(Target=self.instance_id, Parameters=ssm_parameters)
start_session_args = dict(Target=self.instance_id, Parameters={})
document_name = self.get_option('ssm_document')
if document_name is not None:
start_session_args['DocumentName'] = document_name
response = client.start_session(**start_session_args)
response = self._client.start_session(**start_session_args)
self._session_id = response['SessionId']

region_name = self.get_option('region')
profile_name = self.get_option('profile') or ''
cmd = [
executable,
json.dumps(response),
region_name,
"StartSession",
profile_name,
json.dumps({"Target": self.instance_id}),
client.meta.endpoint_url,
self._client.meta.endpoint_url,
]

self._vvvv(f"SSM COMMAND: {to_text(cmd)}")
Expand Down Expand Up @@ -650,14 +674,7 @@ def _flush_stderr(self, session_process):
def _get_url(self, client_method, bucket_name, out_path, http_method, profile_name, extra_args=None):
''' Generate URL for get_object / put_object '''

region_name = self.get_option('region') or 'us-east-1'

bucket_location = self._get_boto_client('s3', region_name=region_name, profile_name=profile_name).get_bucket_location(
Bucket=(self.get_option('bucket_name')),
)
bucket_region_name = bucket_location['LocationConstraint']

client = self._get_boto_client('s3', region_name=bucket_region_name, profile_name=profile_name)
client = self._s3_client
params = {'Bucket': bucket_name, 'Key': out_path}
if extra_args is not None:
params.update(extra_args)
Expand Down Expand Up @@ -731,7 +748,7 @@ def _file_transport_command(self, in_path, out_path, ssm_action):
get_command = "curl '%s' -o '%s'" % (
self._get_url('get_object', self.get_option('bucket_name'), s3_path, 'GET', profile_name), out_path)

client = self._get_boto_client('s3', profile_name=profile_name)
client = self._s3_client
if ssm_action == 'get':
(returncode, stdout, stderr) = self.exec_command(put_command, in_data=None, sudoable=False)
with open(to_bytes(out_path, errors='surrogate_or_strict'), 'wb') as data:
Expand Down

0 comments on commit 974c4e8

Please sign in to comment.