Refactor util files
reweeden committed Jan 22, 2022
1 parent ea4f0c1 commit e067ba8
Showing 5 changed files with 217 additions and 212 deletions.
101 changes: 52 additions & 49 deletions rain_api_core/aws_util.py
@@ -1,52 +1,55 @@
+import functools
+import json
 import logging
 import os
 import sys
 import urllib
-from netaddr import IPAddress, IPNetwork
-from json import loads
 from time import time
-from yaml import safe_load
-from boto3 import client as botoclient, resource as botoresource, session as botosession, Session as boto_Session
+
+from boto3 import Session as boto_Session
+from boto3 import client as botoclient
+from boto3 import resource as botoresource
+from boto3 import session as botosession
 from boto3.resources.base import ServiceResource
 from botocore.config import Config as bc_Config
 from botocore.exceptions import ClientError
+from netaddr import IPAddress, IPNetwork
+from yaml import safe_load
 
-from rain_api_core.general_util import return_timing_object, duration
+from rain_api_core.general_util import duration, return_timing_object
 
 log = logging.getLogger(__name__)
 sts = botoclient('sts')
-secret_cache = {}
 session_cache = {}
 region_list_cache = []
 s3_resource = None
 region = ''
 botosess = botosession.Session()
-role_creds_cache = {os.getenv('EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN'): {}, os.getenv('EGRESS_APP_DOWNLOAD_ROLE_ARN'): {}}
+role_creds_cache = {
+    os.getenv('EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN'): {},
+    os.getenv('EGRESS_APP_DOWNLOAD_ROLE_ARN'): {}
+}
 
 
 def get_region():
     """
     Will determine and return current AWS region.
     :return: string describing AWS region
     :type: string
     """
-    global region #pylint: disable=global-statement
-    global botosess #pylint: disable=global-statement
+    global region  # pylint: disable=global-statement
+    global botosess  # pylint: disable=global-statement
     if not region:
         region = botosess.region_name
     return region
 
 
+@functools.lru_cache(maxsize=None)
 def retrieve_secret(secret_name):
-
-    global secret_cache # pylint: disable=global-statement
-    global botosess # pylint: disable=global-statement
+    global region  # pylint: disable=global-statement
+    global botosess  # pylint: disable=global-statement
    t0 = time()
 
-    if secret_name in secret_cache:
-        log.debug('ET for retrieving secret {} from cache: {} sec'.format(secret_name, round(time() - t0, 4)))
-        return secret_cache[secret_name]
-
     region_name = os.getenv('AWS_DEFAULT_REGION')
 
     # Create a Secrets Manager client
@@ -72,9 +75,7 @@ def retrieve_secret(secret_name):
     # Decrypts secret using the associated KMS CMK.
     # Depending on whether the secret is a string or binary, one of these fields will be populated.
     if 'SecretString' in get_secret_value_response:
-
-        secret = loads(get_secret_value_response['SecretString'])
-        secret_cache[secret_name] = secret
+        secret = json.loads(get_secret_value_response['SecretString'])
         log.debug('ET for retrieving secret {} from secret store: {} sec'.format(secret_name, round(time() - t0, 4)))
         return secret

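Note: the hand-rolled secret_cache dict is gone; functools.lru_cache now memoizes retrieve_secret() on its secret_name argument. A minimal sketch of the same pattern (standalone; the counter merely stands in for the Secrets Manager round trip):

    import functools

    calls = 0

    @functools.lru_cache(maxsize=None)
    def retrieve_secret(secret_name):
        global calls
        calls += 1  # stands in for the expensive Secrets Manager call
        return {"name": secret_name}

    retrieve_secret("db-creds")
    retrieve_secret("db-creds")
    assert calls == 1  # the second call was served from the cache

Unlike the old module-level dict, the memoized values can also be dropped wholesale with retrieve_secret.cache_clear().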
@@ -86,18 +87,19 @@ def get_s3_resource():
     :return: subclass of boto3.resources.base.ServiceResource
     """
-    global s3_resource #pylint: disable=global-statement
+    global s3_resource  # pylint: disable=global-statement
     if not s3_resource:
         params = {}
         # Swift signature compatability
-        if os.getenv('S3_SIGNATURE_VERSION'):
-            params['config'] = bc_Config(signature_version=os.getenv('S3_SIGNATURE_VERSION'))
+        signature_version = os.getenv('S3_SIGNATURE_VERSION')
+        if signature_version:
+            params['config'] = bc_Config(signature_version=signature_version)
         s3_resource = botoresource('s3', **params)
 
     return s3_resource
 
 
-def read_s3(bucket: str, key: str, s3: ServiceResource=None):
+def read_s3(bucket: str, key: str, s3: ServiceResource = None):
     """
     returns file
     :type bucket: str
@@ -117,7 +119,7 @@ def read_s3(bucket: str, key: str, s3: ServiceResource=None):
     obj = s3.Object(bucket, key)
     log.debug('ET for reading {} from S3: {} sec'.format(key, round(time() - t0, 4)))
     timer = time()
-    body = obj.get()['Body'].read().decode('utf-8')
+    body = obj.get()['Body'].read().decode('utf-8')
     log.info(return_timing_object(service="s3", endpoint=f"resource().Object(s3://{bucket}/{key}).get()", duration=duration(timer)))
     return body

@@ -138,7 +140,6 @@ def get_yaml(bucket: str, file_name: str):
 
 
 def get_yaml_file(bucket, key):
-
     if not key:
         # No file was provided, send empty dict
         return {}
@@ -152,50 +153,54 @@ def get_yaml_file(bucket, key):
     # TODO(reweeden): remove this, why is this here!!!?
     sys.exit()
 
-def get_role_creds(user_id: str='', in_region: bool=False):
+
+def get_role_creds(user_id: str = '', in_region: bool = False):
     """
     :param user_id: string with URS username
     :param in_region: boolean If True a download role that works only in region will be returned
     :return: Returns a set of temporary security credentials (consisting of an access key ID, a secret access key, and a security token)
     :return: Offset, in seconds for how long the STS session has been active
     """
-    global sts #pylint: disable=global-statement
+    global sts  # pylint: disable=global-statement
     if not user_id:
         user_id = 'unauthenticated'
 
     if in_region:
         download_role_arn = os.getenv('EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN')
     else:
         download_role_arn = os.getenv('EGRESS_APP_DOWNLOAD_ROLE_ARN')
-    dl_arn_name=download_role_arn.split("/")[-1]
+    dl_arn_name = download_role_arn.split("/")[-1]
 
     # chained role assumption like this CANNOT currently be extended past 1 Hour.
     # https://aws.amazon.com/premiumsupport/knowledge-center/iam-role-chaining-limit/
     now = time()
-    session_params = {"RoleArn": download_role_arn, "RoleSessionName": f"{user_id}@{round(now)}", "DurationSeconds": 3600 }
+    session_params = {
+        "RoleArn": download_role_arn,
+        "RoleSessionName": f"{user_id}@{round(now)}",
+        "DurationSeconds": 3600
+    }
     session_offset = 0
 
     if user_id not in role_creds_cache[download_role_arn]:
         fresh_session = sts.assume_role(**session_params)
         log.info(return_timing_object(service="sts", endpoint=f"client().assume_role({dl_arn_name}/{user_id})", duration=duration(now)))
-        role_creds_cache[download_role_arn][user_id] = {"session": fresh_session, "timestamp": now }
+        role_creds_cache[download_role_arn][user_id] = {"session": fresh_session, "timestamp": now}
     elif now - role_creds_cache[download_role_arn][user_id]["timestamp"] > 600:
         # If the session has been active for more than 10 minutes, grab a new one.
         log.info("Replacing 10 minute old session for {0}".format(user_id))
         fresh_session = sts.assume_role(**session_params)
         log.info(return_timing_object(service="sts", endpoint="client().assume_role()", duration=duration(now)))
-        role_creds_cache[download_role_arn][user_id] = {"session": fresh_session, "timestamp": now }
+        role_creds_cache[download_role_arn][user_id] = {"session": fresh_session, "timestamp": now}
     else:
         log.info("Reusing role credentials for {0}".format(user_id))
-        session_offset = round( now - role_creds_cache[download_role_arn][user_id]["timestamp"] )
+        session_offset = round(now - role_creds_cache[download_role_arn][user_id]["timestamp"])
 
-    log.debug(f'assuming role: {0}, role session username: {1}'.format(download_role_arn,user_id))
+    log.debug(f'assuming role: {0}, role session username: {1}'.format(download_role_arn, user_id))
     return role_creds_cache[download_role_arn][user_id]["session"], session_offset
 
 
 def get_role_session(creds=None, user_id=None):
-
-    global session_cache #pylint: disable=global-statement
+    global session_cache  # pylint: disable=global-statement
     sts_resp = creds if creds else get_role_creds(user_id)[0]
     log.debug('sts_resp: {0}'.format(sts_resp))

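Note: get_role_creds() keeps one STS session per (role ARN, user) pair, reuses it for up to 10 minutes, and reports how long it has been alive, which matters because chained role sessions cannot exceed 1 hour. The reuse-or-refresh pattern in isolation (names and TTL are illustrative, not this library's API):

    import time

    TTL = 600  # replace sessions older than 10 minutes
    _cache = {}  # user_id -> {"session": ..., "timestamp": ...}

    def cached_session(user_id, assume_role):
        now = time.time()
        entry = _cache.get(user_id)
        if entry is None or now - entry["timestamp"] > TTL:
            entry = {"session": assume_role(), "timestamp": now}
            _cache[user_id] = entry
        return entry["session"], round(now - entry["timestamp"])

The returned offset tells the caller how much of the 1-hour window a reused session has already consumed.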
@@ -216,23 +221,21 @@ def get_region_cidr_ranges():
     """
     :return: Utility function to download AWS regions
     """
+    global region_list_cache  # pylint: disable=global-statement
 
-    global region_list_cache #pylint: disable=global-statement
-
-    if not region_list_cache: #pylint: disable=used-before-assignment
+    if not region_list_cache:  # pylint: disable=used-before-assignment
         url = 'https://ip-ranges.amazonaws.com/ip-ranges.json'
         now = time()
         req = urllib.request.Request(url)
-        r = urllib.request.urlopen(req).read() #nosec URL is *always* https://ip-ranges...
+        r = urllib.request.urlopen(req).read()  # nosec URL is *always* https://ip-ranges...
         log.info(return_timing_object(service="AWS", endpoint=url, duration=duration(now)))
-        region_list_json = loads(r.decode('utf-8'))
-        region_list_cache = []
-
+        region_list_json = json.loads(r.decode('utf-8'))
         # Sort out ONLY values from this AWS region
-        for pre in region_list_json["prefixes"]:
-            if "ip_prefix" in pre and "region" in pre:
-                if pre["region"] == get_region():
-                    region_list_cache.append(IPNetwork(pre["ip_prefix"]))
+        this_region = get_region()
+        region_list_cache = [
+            IPNetwork(pre["ip_prefix"]) for pre in region_list_json["prefixes"]
+            if "ip_prefix" in pre and "region" in pre and pre["region"] == this_region
+        ]
 
     return region_list_cache

@@ -244,9 +247,9 @@ def check_in_region_request(ip_addr: str):
     :type: Boolean
     """
 
+    addr = IPAddress(ip_addr)
     for cidr in get_region_cidr_ranges():
-        #log.debug("Checking ip {0} vs cidr {1}".format(user_ip, cidr))
-        if IPAddress(ip_addr) in cidr:
+        if addr in cidr:
             log.info("IP {0} matched in-region CIDR {1}".format(ip_addr, cidr))
             return True

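Note: IPAddress(ip_addr) is now parsed once, before the loop, instead of once per CIDR. With netaddr, membership is a plain `in` test; a small sketch (CIDRs are arbitrary examples, not real AWS ranges):

    from netaddr import IPAddress, IPNetwork

    ranges = [IPNetwork("10.1.0.0/16"), IPNetwork("192.0.2.0/24")]

    def in_region(ip_addr, ranges):
        addr = IPAddress(ip_addr)  # parse once, not once per CIDR
        return any(addr in cidr for cidr in ranges)

    print(in_region("192.0.2.17", ranges))    # True
    print(in_region("198.51.100.1", ranges))  # False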
58 changes: 29 additions & 29 deletions rain_api_core/egress_util.py
@@ -1,31 +1,29 @@
-import logging
 import hmac
-from hashlib import sha256
+import logging
 import os
 import urllib
 from datetime import datetime
+from hashlib import sha256
 
 log = logging.getLogger(__name__)
 
 # This warning is stupid
 # pylint: disable=logging-fstring-interpolation
 
-def prepend_bucketname(name):
-
-    prefix = os.getenv('BUCKETNAME_PREFIX', "gsfc-ngap-{}-".format(os.getenv('MATURITY', 'DEV')[0:1].lower()))
+
+def prepend_bucketname(name):
+    prefix = os.getenv('BUCKETNAME_PREFIX', "gsfc-ngap-{}-".format(os.getenv('MATURITY', 'DEV')[:1].lower()))
     return "{}{}".format(prefix, name)
 
 
 def hmacsha256(key, string):
-
     return hmac.new(key, string.encode('utf-8'), sha256)
 
 
 def get_presigned_url(session, bucket_name, object_name, region_name, expire_seconds, user_id, method='GET'):
-
     timez = datetime.utcnow().strftime('%Y%m%dT%H%M%SZ')
     datez = timez[:8]
-    hostname = "{0}.s3{1}.amazonaws.com".format(bucket_name, "."+region_name if region_name != "us-east-1" else "")
+    hostname = "{0}.s3{1}.amazonaws.com".format(bucket_name, "." + region_name if region_name != "us-east-1" else "")
 
     cred = session['Credentials']['AccessKeyId']
     secret = session['Credentials']['SecretAccessKey']
@@ -53,11 +51,10 @@ def get_presigned_url(session, bucket_name, object_name, region_name, expire_seconds, user_id, method='GET'):
     stringtosign = "\n".join(["AWS4-HMAC-SHA256", timez, aws4_request, can_req_hash])
 
     # Signing Key
-    StepOne = hmacsha256( "AWS4{0}".format(secret).encode('utf-8'), datez).digest()
-    StepTwo = hmacsha256( StepOne, region_name ).digest()
-    StepThree = hmacsha256( StepTwo, "s3").digest()
-    SigningKey = hmacsha256( StepThree, "aws4_request").digest()
-
+    StepOne = hmacsha256("AWS4{0}".format(secret).encode('utf-8'), datez).digest()
+    StepTwo = hmacsha256(StepOne, region_name).digest()
+    StepThree = hmacsha256(StepTwo, "s3").digest()
+    SigningKey = hmacsha256(StepThree, "aws4_request").digest()
 
     # Final Signature
     Signature = hmacsha256(SigningKey, stringtosign).hexdigest()
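Note: the four hmacsha256 steps are the standard AWS Signature Version 4 key derivation: each stage keys an HMAC with the previous digest, folding in the date, region, service, and the literal "aws4_request". A self-contained rendering (the secret is the example key from the AWS documentation, not a real credential):

    import hmac
    from hashlib import sha256

    def hmacsha256(key, string):
        return hmac.new(key, string.encode('utf-8'), sha256)

    secret = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
    datez = "20220122"

    step_one = hmacsha256("AWS4{0}".format(secret).encode('utf-8'), datez).digest()
    step_two = hmacsha256(step_one, "us-east-1").digest()
    step_three = hmacsha256(step_two, "s3").digest()
    signing_key = hmacsha256(step_three, "aws4_request").digest()

Because the key depends only on date, region and service, the same signing key is valid for every request signed that day against that service.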
@@ -68,7 +65,6 @@ def get_presigned_url(session, bucket_name, object_name, region_name, expire_seconds, user_id, method='GET'):
 
 
 def get_bucket_dynamic_path(path_list, b_map):
-
     # Old and REVERSE format has no 'MAP'. In either case, we don't want it fouling our dict.
     if 'MAP' in b_map:
         map_dict = b_map['MAP']
@@ -81,7 +77,7 @@ def get_bucket_dynamic_path(path_list, b_map):
     # walk the bucket map to see if this path is valid
     for path_part in path_list:
         # Check if we hit a leaf of the YAML tree
-        if (mapping and isinstance(map_dict, str)) or 'bucket' in map_dict: #
+        if (mapping and isinstance(map_dict, str)) or 'bucket' in map_dict:
             customheaders = {}
             if isinstance(map_dict, dict) and 'bucket' in map_dict:
                 bucketname = map_dict['bucket']
@@ -130,31 +126,33 @@ def process_varargs(varargs: list, b_map: dict):
 
 
 def process_request(varargs, b_map):
-
-    varargs = varargs.split("/")
+    split_args = varargs.split("/")
 
     # Make sure we got at least 1 path, and 1 file name:
-    if len(varargs) < 2:
-        return "/".join(varargs), None, None, []
+    if len(split_args) < 2:
+        return varargs, None, None, {}
 
     # Watch for ASF-ish reverse URL mapping formats:
-    if len(varargs) == 3:
+    if len(split_args) == 3:
         if os.getenv('USE_REVERSE_BUCKET_MAP', 'FALSE').lower() == 'true':
-            varargs[0], varargs[1] = varargs[1], varargs[0]
+            split_args[0], split_args[1] = split_args[1], split_args[0]
 
     # Look up the bucket from path parts
-    bucket, path, object_name, headers = get_bucket_dynamic_path(varargs, b_map)
+    bucket, path, object_name, headers = get_bucket_dynamic_path(split_args, b_map)
 
     # If we didn't figure out the bucket, we don't know the path/object_name
     if not bucket:
-        object_name = varargs.pop(-1)
-        path = "/".join(varargs)
+        object_name = split_args.pop(-1)
+        path = "/".join(split_args)
 
     return path, bucket, object_name, headers
 
 
 def bucket_prefix_match(bucket_check, bucket_map, object_name=""):
+    # NOTE: https://github.com/asfadmin/thin-egress-app/issues/188
     log.debug(f"bucket_prefix_match(): checking if {bucket_check} matches {bucket_map} w/ optional obj '{object_name}'")
-    if bucket_check == bucket_map.split('/')[0] and object_name.startswith("/".join(bucket_map.split('/')[1:])):
+    prefix, *tail = bucket_map.split("/", 1)
+    if bucket_check == prefix and object_name.startswith("/".join(tail)):
        log.debug(f"Prefixed Bucket Map matched: s3://{bucket_check}/{object_name} => {bucket_map}")
         return True
     return False
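Note: the rewritten bucket_prefix_match() splits the map entry once with split("/", 1) rather than twice. Its behavior, condensed into a standalone sketch with invented bucket names:

    def prefix_match(bucket_check, bucket_map, object_name=""):
        prefix, *tail = bucket_map.split("/", 1)
        return bucket_check == prefix and object_name.startswith("/".join(tail))

    assert prefix_match("data", "data")  # bare bucket entry, matches any object
    assert prefix_match("data", "data/browse", "browse/img.png")
    assert not prefix_match("data", "data/browse", "thumbs/img.png")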
@@ -168,21 +166,22 @@ def get_sorted_bucket_list(b_map, bucket_group):
         return []
 
     # b_map[bucket_group] SHOULD be a dict, but list actually works too.
-    if isinstance(b_map[bucket_group], dict):
-        return sorted(list(b_map[bucket_group].keys()), key=lambda e: e.count("/"), reverse=True )
+    if isinstance(b_map[bucket_group], dict):
+        return sorted(list(b_map[bucket_group].keys()), key=lambda e: e.count("/"), reverse=True)
     if isinstance(b_map[bucket_group], list):
-        return sorted(list(b_map[bucket_group]), key=lambda e: e.count("/"), reverse=True )
+        return sorted(list(b_map[bucket_group]), key=lambda e: e.count("/"), reverse=True)
 
     # Something went wrong.
     return []
 
-def check_private_bucket(bucket, b_map, object_name=""):
-
+
+def check_private_bucket(bucket, b_map, object_name=""):
     log.debug('check_private_buckets(): bucket: {}'.format(bucket))
 
     # Check public bucket file:
     if 'PRIVATE_BUCKETS' in b_map:
         # Prioritize prefixed buckets first, the deeper the better!
+        # TODO(reweeden): cache the sorted list (refactoring to object would be easiest)
         sorted_buckets = get_sorted_bucket_list(b_map, 'PRIVATE_BUCKETS')
         log.debug(f"Sorted PRIVATE buckets are {sorted_buckets}")
         for priv_bucket in sorted_buckets:
@@ -192,10 +191,11 @@ def check_private_bucket(bucket, b_map, object_name=""):
 
     return False
 
-def check_public_bucket(bucket, b_map, object_name=""):
-
+
+def check_public_bucket(bucket, b_map, object_name=""):
     # Check for PUBLIC_BUCKETS in bucket map file
     if 'PUBLIC_BUCKETS' in b_map:
+        # TODO(reweeden): cache the sorted list (refactoring to object would be easiest)
         sorted_buckets = get_sorted_bucket_list(b_map, 'PUBLIC_BUCKETS')
         log.debug(f"Sorted PUBLIC buckets are {sorted_buckets}")
         for pub_bucket in sorted_buckets:
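Note: both checks lean on get_sorted_bucket_list() ordering entries by their number of path separators, so the deepest (most specific) prefix is tried first:

    buckets = ["data", "data/browse", "data/browse/jpg"]
    print(sorted(buckets, key=lambda e: e.count("/"), reverse=True))
    # ['data/browse/jpg', 'data/browse', 'data']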
